1use std::{
11 collections::VecDeque,
12 io::IoSliceMut,
13 num::NonZeroUsize,
14 pin::Pin,
15 sync::Arc,
16 task::{Poll, Waker},
17};
18
19use msgq::MsgQueue;
20use parking_lot::Mutex;
21use tokio::io::AsyncRead;
22use tracing::{debug, trace};
23
/// A message traveling from the dispatcher to the user-facing reader
/// through the shared `MsgQueue`.
#[derive(Debug, PartialEq, Eq)]
enum UserRxMessage {
    /// In-order payload bytes ready to be copied out to the reader.
    Payload(Payload),
    /// Remote sent ST_FIN; the reader returns Ok(0) after draining.
    Eof,
    /// Terminal failure; surfaced to the reader as io::Error.
    Error(String),
}
30
31impl UserRxMessage {
32 pub fn len_bytes(&self) -> usize {
33 match &self {
34 UserRxMessage::Payload(payload) => payload.len(),
35 _ => 0,
36 }
37 }
38}
39
/// Contents of one slot in the out-of-order reassembly queue.
/// An empty Payload doubles as the "vacant slot" sentinel (see
/// `ooq_slot_is_default`); real ST_DATA payloads are never empty
/// (rejected with `Error::ZeroPayloadStData` in `add_remove`).
#[derive(Clone, Debug, PartialEq, Eq)]
enum OoqMessage {
    Payload(Payload),
    Eof,
}
45
46impl Default for OoqMessage {
47 fn default() -> Self {
48 OoqMessage::Payload(Vec::new())
49 }
50}
51
52impl OoqMessage {
53 pub fn len_bytes(&self) -> usize {
54 match &self {
55 OoqMessage::Payload(payload) => payload.len(),
56 _ => 0,
57 }
58 }
59}
60
61mod msgq {
62 use std::collections::VecDeque;
63
64 use super::{OoqMessage, UserRxMessage};
65
66 pub struct MsgQueue {
67 queue: VecDeque<UserRxMessage>,
68 len_bytes: usize,
69 capacity: usize,
70 }
71
72 impl MsgQueue {
73 pub fn new(capacity: usize) -> Self {
74 Self {
75 queue: Default::default(),
76 len_bytes: 0,
77 capacity,
78 }
79 }
80
81 #[cfg(test)]
82 pub fn len_bytes(&self) -> usize {
83 self.len_bytes
84 }
85
86 pub fn window(&self) -> usize {
87 self.capacity.saturating_sub(self.len_bytes)
88 }
89
90 #[cfg(test)]
91 pub fn is_full(&self) -> bool {
92 self.len_bytes >= self.capacity
93 }
94
95 pub fn pop_front(&mut self) -> Option<UserRxMessage> {
96 let msg = self.queue.pop_front()?;
97 self.len_bytes -= msg.len_bytes();
98 Some(msg)
99 }
100
101 pub fn try_push_back(&mut self, msg: OoqMessage) -> Result<(), OoqMessage> {
102 let len = msg.len_bytes();
103 if self.capacity - self.len_bytes < len {
104 return Err(msg);
105 }
106 self.queue.push_back(match msg {
107 OoqMessage::Payload(payload) => UserRxMessage::Payload(payload),
108 OoqMessage::Eof => UserRxMessage::Eof,
109 });
110 self.len_bytes += len;
111 Ok(())
112 }
113
114 pub(crate) fn push_back(&mut self, msg: UserRxMessage) {
115 self.len_bytes += msg.len_bytes();
116 self.queue.push_back(msg);
117 }
118 }
119}
120
121use crate::{
122 Error, Payload,
123 message::UtpMessage,
124 raw::{Type, selective_ack::SelectiveAck},
125 utils::update_optional_waker,
126};
127
/// User-facing read half of a uTP stream. Pulls messages from the queue
/// shared with the dispatcher-side write half (`UserRx`).
pub struct UtpStreamReadHalf {
    // Partially consumed payload carried over between poll calls.
    current: Option<BeingRead>,
    // Set once an Eof message was popped; subsequent reads return Ok(0).
    is_eof: bool,
    shared: Arc<UserRxShared>,
}
133
134impl UtpStreamReadHalf {
135 pub fn poll_read_vectored(
136 mut self: Pin<&mut Self>,
137 cx: &mut std::task::Context<'_>,
138 mut bufs: &mut [IoSliceMut<'_>],
139 ) -> Poll<std::io::Result<usize>> {
140 let mut written = 0usize;
141 let mut dispatcher_dead = false;
142
143 while let Some(current_buf) = bufs.first_mut() {
144 if current_buf.is_empty() {
145 bufs = &mut bufs[1..];
146 continue;
147 }
148 if let Some(current) = self.current.as_mut() {
150 let payload = ¤t.payload[current.offset..];
151 if payload.is_empty() {
152 return Poll::Ready(Err(std::io::Error::other(
153 "bug in UtpStreamReadHalf: payload is empty",
154 )));
155 }
156
157 let len = current_buf.len().min(payload.len());
158 current_buf[..len].copy_from_slice(&payload[..len]);
159 current_buf.advance(len);
160
161 written += len;
162 current.offset += len;
163 if current.offset == current.payload.len() {
164 self.current = None;
165 }
166 continue;
167 }
168
169 if self.is_eof {
170 break;
171 }
172
173 let mut g = self.shared.locked.lock();
174 if let Some(msg) = g.queue.pop_front() {
175 match msg {
176 UserRxMessage::Eof => {
177 drop(g);
178 self.is_eof = true;
179 break;
180 }
181 UserRxMessage::Payload(payload) => {
182 drop(g);
183 self.current = Some(BeingRead { payload, offset: 0 })
184 }
185 UserRxMessage::Error(msg) => {
186 return Poll::Ready(Err(std::io::Error::other(msg)));
187 }
188 }
189 } else {
190 if g.vsock_closed {
191 dispatcher_dead = true;
192 } else {
193 update_optional_waker(&mut g.reader_waker, cx);
194 }
195 break;
196 }
197 }
198
199 if written > 0 {
200 let mut g = self.shared.locked.lock();
201 let waker = g.dispatcher_waker.take();
202 drop(g);
203 if let Some(waker) = waker {
204 waker.wake();
205 }
206 return Poll::Ready(Ok(written));
207 }
208
209 if self.is_eof {
210 return Poll::Ready(Ok(0));
211 }
212
213 if dispatcher_dead {
214 return Poll::Ready(Err(std::io::Error::other("dispatcher dead")));
215 }
216
217 Poll::Pending
218 }
219
220 #[cfg(test)]
221 pub async fn read_all_available(&mut self) -> std::io::Result<Vec<u8>> {
222 let mut buf = vec![0u8; 2 * 1024 * 1024];
223 let mut offset = 0;
224 let mut g = self.shared.locked.lock();
225 while let Some(m) = g.queue.pop_front() {
226 match m {
227 UserRxMessage::Payload(payload) => {
228 buf[offset..offset + payload.len()].copy_from_slice(&payload);
229 offset += payload.len();
230 }
231 UserRxMessage::Eof => {
232 break;
233 }
234 UserRxMessage::Error(e) => return Err(std::io::Error::other(e)),
235 }
236 }
237 buf.truncate(offset);
238 Ok(buf)
239 }
240}
241
242impl AsyncRead for UtpStreamReadHalf {
243 fn poll_read(
244 self: Pin<&mut Self>,
245 cx: &mut std::task::Context<'_>,
246 buf: &mut tokio::io::ReadBuf<'_>,
247 ) -> Poll<std::io::Result<()>> {
248 let mut iovecs = [IoSliceMut::new(buf.initialize_unfilled())];
249 let len = std::task::ready!(self.poll_read_vectored(cx, &mut iovecs)?);
250 buf.advance(len);
251 Poll::Ready(Ok(()))
252 }
253}
254
/// A payload currently being copied out to the reader, plus how many of
/// its bytes were already consumed.
struct BeingRead {
    payload: Payload,
    offset: usize,
}
259
/// State shared between reader and dispatcher, guarded by
/// `UserRxShared::locked`.
struct UserRxSharedLocked {
    // Set by UtpStreamReadHalf::drop; tells the dispatcher to stop sending.
    reader_dropped: bool,
    // Set by UserRx::mark_vsock_closed; tells the reader no more data comes.
    vsock_closed: bool,
    queue: MsgQueue,
    // Woken when the reader frees queue space or goes away.
    dispatcher_waker: Option<Waker>,
    // Woken when data/EOF/error is queued or the vsock closes.
    reader_waker: Option<Waker>,
}
267
/// Handle shared (via Arc) between the read half and the dispatcher-side
/// write half.
struct UserRxShared {
    locked: Mutex<UserRxSharedLocked>,
}
271
272impl Drop for UtpStreamReadHalf {
273 fn drop(&mut self) {
274 let mut g = self.shared.locked.lock();
275 g.reader_dropped = true;
276 let waker = g.dispatcher_waker.take();
277 drop(g);
278 if let Some(waker) = waker {
279 waker.wake();
280 }
281 }
282}
283
impl Drop for UserRx {
    // Dropping the write half means the dispatcher is gone: mark the
    // virtual socket closed so a pending reader wakes and observes it.
    fn drop(&mut self) {
        self.mark_vsock_closed();
    }
}
289
impl UserRxShared {
    // Test helper: whether the shared queue has reached its byte capacity.
    #[cfg(test)]
    pub fn is_full_test(&self) -> bool {
        self.locked.lock().queue.is_full()
    }
}
296
/// Dispatcher-side write half: reassembles incoming packets in sequence
/// order and feeds them to the reader through the shared queue.
pub struct UserRx {
    shared: Arc<UserRxShared>,
    // Buffers out-of-order packets until the gaps are filled.
    ooq: OutOfOrderQueue,
    max_incoming_payload: NonZeroUsize,
    // Queue window observed at the end of the last flush(); basis for
    // the advertised receive window.
    last_remaining_rx_window: usize,
}
303
impl UserRx {
    /// Builds the dispatcher-side write half and the user-facing read
    /// half, connected through a shared queue bounded at `max_rx_bytes`.
    pub fn build(
        max_rx_bytes: NonZeroUsize,
        max_incoming_payload: NonZeroUsize,
    ) -> (UserRx, UtpStreamReadHalf) {
        let shared = Arc::new(UserRxShared {
            locked: Mutex::new(UserRxSharedLocked {
                dispatcher_waker: None,
                reader_waker: None,
                queue: MsgQueue::new(max_rx_bytes.get()),
                reader_dropped: false,
                vsock_closed: false,
            }),
        });
        let read_half = UtpStreamReadHalf {
            current: None,
            shared: shared.clone(),
            is_eof: false,
        };
        // Size the reassembly queue by how many max-size packets fit the
        // RX budget; fall back to 64 slots if the division yields zero.
        let ooq_capacity = max_rx_bytes.get() / max_incoming_payload.get();
        let out_of_order_queue = OutOfOrderQueue::new(
            NonZeroUsize::new(ooq_capacity).unwrap_or_else(|| NonZeroUsize::new(64).unwrap()),
        );
        let write_half = UserRx {
            shared,
            ooq: out_of_order_queue,
            max_incoming_payload,
            last_remaining_rx_window: max_rx_bytes.get(),
        };
        (write_half, read_half)
    }

    /// True once the read half was dropped (set in its Drop impl).
    pub fn is_reader_dropped(&self) -> bool {
        self.shared.locked.lock().reader_dropped
    }

    /// Receive window to advertise: the queue window seen at the last
    /// flush() minus bytes parked in the reassembler. Zero once the
    /// reader is gone.
    pub fn remaining_rx_window(&self) -> usize {
        if self.is_reader_dropped() {
            0
        } else {
            self.last_remaining_rx_window
                .saturating_sub(self.ooq.stored_bytes())
        }
    }

    /// Marks the virtual socket closed (idempotent) and wakes a pending
    /// reader. The lock is dropped before waking so the woken task does
    /// not immediately block on the mutex.
    pub fn mark_vsock_closed(&self) {
        let mut g = self.shared.locked.lock();
        if !g.vsock_closed {
            trace!("user_rx: marking vsock closed");
            g.vsock_closed = true;
            let waker = g.reader_waker.take();
            drop(g);
            if let Some(waker) = waker {
                waker.wake();
            }
        }
    }

    /// Moves the in-order prefix of the reassembler into the user queue,
    /// as far as the queue's byte window allows. Registers the
    /// dispatcher waker when the window cannot take another max-size
    /// packet, and wakes the reader if anything was flushed.
    /// Returns the number of bytes flushed.
    pub fn flush(&mut self, cx: &mut std::task::Context<'_>) -> crate::Result<usize> {
        let filled_front_bytes: usize = self.ooq.filled_front_bytes();
        let mut remaining_rx_window = {
            let mut g = self.shared.locked.lock();
            let remaining_window = g.queue.window();
            // Not enough room for the pending prefix plus one more full
            // packet: ask to be woken when the reader drains the queue.
            if remaining_window.saturating_sub(filled_front_bytes) < self.max_incoming_payload.get()
            {
                update_optional_waker(&mut g.dispatcher_waker, cx);
            }
            remaining_window
        };

        let mut flushed_bytes = 0;
        let mut flushed_packets = 0;

        while let Some(len) = self.ooq.send_front_if_fits(remaining_rx_window, |msg| {
            let mut g = self.shared.locked.lock();
            if g.reader_dropped {
                debug_every_ms!(5000, "reader is dead, could not send UtpMesage to it");
                return Err(msg);
            }
            // Cannot fail: the message fits the window snapshot, and the
            // reader only drains the queue, so the real window can only
            // have grown since the snapshot was taken.
            g.queue.try_push_back(msg).unwrap();
            Ok(())
        }) {
            flushed_bytes += len;
            remaining_rx_window -= len;
            flushed_packets += 1;
        }

        if flushed_bytes > 0 {
            // Take the waker under the lock, wake outside of it.
            let waker = self.shared.locked.lock().reader_waker.take();
            if let Some(w) = waker {
                w.wake();
            }
            trace!(
                packets = flushed_packets,
                bytes = flushed_bytes,
                "flushed from out-of-order user RX"
            );
        }

        if self.ooq.filled_front > 0 {
            trace!(
                flushed_bytes,
                flushed_packets,
                out_of_order_filled_front = self.ooq.filled_front,
                remaining_rx_window,
                "did not flush everything"
            );
        }

        self.last_remaining_rx_window = remaining_rx_window;
        Ok(flushed_bytes)
    }

    /// Queues a terminal error for the reader, bypassing the capacity
    /// limit (errors must not be dropped), and wakes a pending reader.
    pub fn enqueue_error(&self, msg: String) {
        let mut g = self.shared.locked.lock();
        g.queue.push_back(UserRxMessage::Error(msg));
        let waker = g.reader_waker.take();
        if let Some(waker) = waker {
            drop(g);
            waker.wake();
        }
    }

    /// Selective-ACK bitfield describing out-of-order packets held in
    /// the reassembler, if any.
    pub fn selective_ack(&self) -> Option<SelectiveAck> {
        self.ooq.selective_ack()
    }

    #[cfg(test)]
    pub fn len_test(&self) -> usize {
        self.shared.locked.lock().queue.len_bytes()
    }

    /// True when no out-of-order packets are waiting in the reassembler.
    pub fn assembler_empty(&self) -> bool {
        self.ooq.is_empty()
    }

    #[cfg(test)]
    pub fn assembler_packets(&self) -> usize {
        self.ooq.stored_packets()
    }

    /// Inserts an incoming message at `offset` slots past the next
    /// expected sequence number. If the insertion extended the in-order
    /// prefix and the reassembler is now full, flushes eagerly to make
    /// room for further packets.
    pub fn add_remove(
        &mut self,
        cx: &mut std::task::Context<'_>,
        msg: UtpMessage,
        offset: usize,
    ) -> crate::Result<AssemblerAddRemoveResult> {
        match self.ooq.add_remove(msg, offset)? {
            res @ AssemblerAddRemoveResult::Consumed {
                sequence_numbers, ..
            } if sequence_numbers > 0 && self.ooq.is_full() => {
                self.flush(cx)?;
                Ok(res)
            }
            res => Ok(res),
        }
    }

    // Test helper: drives add_remove with a one-shot context.
    #[cfg(test)]
    async fn add_remove_test(
        &mut self,
        msg: UtpMessage,
        offset: usize,
    ) -> crate::Result<AssemblerAddRemoveResult> {
        let mut msg = Some(msg);
        let msg = &mut msg;
        std::future::poll_fn(move |cx| {
            let res = self.add_remove(cx, msg.take().unwrap(), offset);
            Poll::Ready(res)
        })
        .await
    }

    #[cfg(test)]
    pub fn is_flush_waker_registered(&self) -> bool {
        self.shared.locked.lock().dispatcher_waker.is_some()
    }

    // Test helper: pushes a message directly, ignoring the capacity limit.
    #[cfg(test)]
    fn enqueue_test(&self, msg: UserRxMessage) {
        let mut g = self.shared.locked.lock();
        g.queue.push_back(msg);
    }
}
500
/// Fixed-capacity reassembly buffer. Slot i holds the packet expected i
/// sequence numbers past the in-order prefix; vacant slots hold the
/// default (empty-payload) OoqMessage.
pub struct OutOfOrderQueue {
    data: VecDeque<OoqMessage>,
    // Length of the contiguous occupied prefix (in order, flushable).
    filled_front: usize,
    // Total occupied slots (in-order + out-of-order).
    len: usize,
    // Total payload bytes across occupied slots.
    len_bytes: usize,
    capacity: usize,
}
508
/// Outcome of `OutOfOrderQueue::add_remove`.
#[derive(Debug, PartialEq, Eq)]
pub enum AssemblerAddRemoveResult {
    /// Stored; `sequence_numbers`/`bytes` report by how much the
    /// contiguous in-order prefix grew as a result (0 if out of order).
    Consumed {
        sequence_numbers: usize,
        bytes: usize,
    },
    /// The slot was already occupied; the duplicate was dropped.
    AlreadyPresent,
    /// Queue full or offset beyond the window; message handed back.
    Unavailable(UtpMessage),
}
518
519fn ooq_slot_is_default(slot: &OoqMessage) -> bool {
520 match slot {
521 OoqMessage::Payload(payload) => payload.is_empty(),
522 OoqMessage::Eof => false,
523 }
524}
525
impl OutOfOrderQueue {
    /// Creates a queue of `capacity` slots, all vacant.
    pub fn new(capacity: NonZeroUsize) -> Self {
        Self {
            data: VecDeque::from(vec![Default::default(); capacity.get()]),
            filled_front: 0,
            len: 0,
            len_bytes: 0,
            capacity: capacity.get(),
        }
    }

    /// Byte size of the contiguous in-order prefix (the part ready to
    /// be handed to the user queue).
    fn filled_front_bytes(&self) -> usize {
        self.data
            .iter()
            .take(self.filled_front)
            .map(|m| m.len_bytes())
            .sum()
    }

    /// Pops the front in-order message and passes it to `send_fn` if it
    /// fits into `window` bytes. On success rotates a fresh vacant slot
    /// to the back (keeping data.len() == capacity) and returns the
    /// message's byte length. If `send_fn` refuses (returns the message
    /// back), state is restored and None is returned.
    fn send_front_if_fits(
        &mut self,
        window: usize,
        send_fn: impl FnOnce(OoqMessage) -> Result<(), OoqMessage>,
    ) -> Option<usize> {
        if self.filled_front == 0 {
            return None;
        }
        if self.data[0].len_bytes() > window {
            return None;
        }
        let msg = self.data.pop_front().unwrap();
        let len = msg.len_bytes();
        match send_fn(msg) {
            Ok(()) => {}
            Err(msg) => {
                // Receiver refused (e.g. reader dropped): put it back.
                self.data.push_front(msg);
                return None;
            }
        }
        self.filled_front -= 1;
        self.len -= 1;
        self.len_bytes -= len;
        self.data.push_back(Default::default());
        Some(len)
    }

    /// True when nothing is stored past the in-order prefix, i.e. no
    /// out-of-order packets are waiting.
    pub fn is_empty(&self) -> bool {
        self.filled_front == self.len
    }

    /// True when every slot is occupied.
    pub fn is_full(&self) -> bool {
        self.len == self.capacity
    }

    #[cfg(test)]
    fn filled_front(&self) -> usize {
        self.filled_front
    }

    #[cfg(test)]
    fn stored_packets(&self) -> usize {
        self.len
    }

    fn stored_bytes(&self) -> usize {
        self.len_bytes
    }

    // Test helper: lazily formatted summary (optionally with contents).
    #[cfg(test)]
    fn debug_string(&self, with_data: bool) -> impl std::fmt::Display + '_ {
        struct D<'a> {
            q: &'a OutOfOrderQueue,
            with_data: bool,
        }
        impl std::fmt::Display for D<'_> {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                write!(
                    f,
                    "len={}, len_bytes={}",
                    self.q.stored_packets(),
                    self.q.stored_bytes(),
                )?;

                if !self.with_data {
                    return Ok(());
                }

                write!(f, ", queue={:?}", self.q.data)?;
                Ok(())
            }
        }
        D { q: self, with_data }
    }

    /// Builds a selective ACK over the slots past the next expected
    /// sequence number, or None if nothing is out of order.
    pub fn selective_ack(&self) -> Option<SelectiveAck> {
        if self.is_empty() {
            return None;
        }

        // Slot `filled_front` is the next expected packet — presumably
        // covered by the cumulative ACK, so the bitfield starts one past
        // it. NOTE(review): inferred from usage; confirm against the
        // SelectiveAck encoding.
        let start = self.filled_front + 1;
        if start >= self.data.len() {
            return None;
        }
        let unacked = self
            .data
            .range(start..)
            .enumerate()
            .filter_map(|(idx, data)| {
                if ooq_slot_is_default(data) {
                    None
                } else {
                    Some(idx)
                }
            });

        Some(SelectiveAck::new(unacked))
    }

    /// Inserts `msg` at `offset` slots past the next expected sequence
    /// number. See `AssemblerAddRemoveResult` for the possible outcomes.
    pub fn add_remove(
        &mut self,
        msg: UtpMessage,
        offset: usize,
    ) -> crate::Result<AssemblerAddRemoveResult> {
        if self.is_full() {
            debug!(offset, "assembler buffer full");
            return Ok(AssemblerAddRemoveResult::Unavailable(msg));
        }

        // Caller offsets are relative to the flush position; the
        // in-order prefix still occupies the first filled_front slots.
        let effective_offset = offset + self.filled_front;

        if effective_offset >= self.data.len() {
            trace!(
                offset,
                self.filled_front, effective_offset, "message is past assembler's window"
            );
            return Ok(AssemblerAddRemoveResult::Unavailable(msg));
        }

        // Only ST_DATA (non-empty) and ST_FIN may enter the assembler.
        let msg = match msg.header.htype {
            Type::ST_DATA if msg.payload().is_empty() => return Err(Error::ZeroPayloadStData),
            Type::ST_DATA => OoqMessage::Payload(msg.data),
            Type::ST_FIN => OoqMessage::Eof,
            _ => return Err(Error::BugInvalidMessageExpectedStDataOrFin),
        };

        let slot = self
            .data
            .get_mut(effective_offset)
            .ok_or(Error::BugAssemblerMissingSlot(effective_offset))?;
        if !ooq_slot_is_default(slot) {
            return Ok(AssemblerAddRemoveResult::AlreadyPresent);
        }

        self.len += 1;
        self.len_bytes += msg.len_bytes();
        *slot = msg;

        // Advance filled_front over the run made contiguous by this
        // insertion and report how much it grew.
        let range = self.filled_front..self.data.len();
        let (consumed_segments, consumed_bytes) = self
            .data
            .range(range)
            .take_while(|msg| !ooq_slot_is_default(msg))
            .fold((0, 0), |mut state, msg| {
                state.0 += 1;
                state.1 += msg.len_bytes();
                state
            });
        self.filled_front += consumed_segments;
        Ok(AssemblerAddRemoveResult::Consumed {
            sequence_numbers: consumed_segments,
            bytes: consumed_bytes,
        })
    }
}
701
#[cfg(test)]
mod tests {
    use std::{future::poll_fn, num::NonZeroUsize, task::Poll};

    use tokio::io::AsyncReadExt;
    use tracing::trace;

    use crate::{
        message::UtpMessage,
        stream_rx::{AssemblerAddRemoveResult, OutOfOrderQueue, UserRxMessage},
        test_util::setup_test_logging,
    };

    use super::{UserRx, UtpStreamReadHalf};

    // Builds an ST_DATA message with the given seq_nr and payload.
    fn msg(seq_nr: u16, payload: &[u8]) -> UtpMessage {
        UtpMessage::new_test(
            crate::raw::UtpHeader {
                htype: crate::raw::Type::ST_DATA,
                seq_nr: seq_nr.into(),
                ..Default::default()
            },
            payload,
        )
    }

    // Builds a UserRx/read-half pair with the given RX byte budget.
    fn user_rx(capacity_bytes: usize) -> (UserRx, UtpStreamReadHalf) {
        UserRx::build(
            NonZeroUsize::new(capacity_bytes).unwrap(),
            NonZeroUsize::new(1500).unwrap(),
        )
    }

    // An in-order insertion immediately extends the filled prefix.
    #[test]
    fn test_asm_add_one_in_order() {
        let mut asm = OutOfOrderQueue::new(NonZeroUsize::new(2).unwrap());
        assert_eq!(
            asm.add_remove(msg(0, b"a"), 0).unwrap(),
            AssemblerAddRemoveResult::Consumed {
                sequence_numbers: 1,
                bytes: 1
            }
        );
        assert_eq!(asm.stored_packets(), 1);
        assert_eq!(asm.stored_bytes(), 1);
        assert_eq!(asm.filled_front(), 1);
    }

    // An out-of-order insertion is stored but grows no prefix.
    #[test]
    fn test_asm_add_one_out_of_order() {
        let mut asm = OutOfOrderQueue::new(NonZeroUsize::new(2).unwrap());
        assert_eq!(
            asm.add_remove(msg(100, b"a"), 1).unwrap(),
            AssemblerAddRemoveResult::Consumed {
                sequence_numbers: 0,
                bytes: 0
            }
        );
        assert_eq!(asm.stored_packets(), 1);
        assert_eq!(asm.stored_bytes(), 1);
        assert_eq!(asm.filled_front(), 0);
    }

    // With the user queue full, an in-order packet still enters the
    // assembler (it just can't be flushed yet).
    #[tokio::test]
    async fn test_asm_channel_full_asm_empty() {
        setup_test_logging();
        let (mut user_rx, _read) = user_rx(1);
        let msg = msg(0, b"a");

        user_rx.enqueue_test(UserRxMessage::Payload(b"a".to_vec()));

        assert!(user_rx.shared.is_full_test());

        assert_eq!(
            user_rx.add_remove_test(msg.clone(), 0).await.unwrap(),
            AssemblerAddRemoveResult::Consumed {
                sequence_numbers: 1,
                bytes: 1
            }
        );
        assert_eq!(user_rx.ooq.stored_packets(), 1);
        assert_eq!(user_rx.ooq.stored_bytes(), 1);
        assert_eq!(user_rx.ooq.filled_front(), 1);
    }

    // Filling a gap merges the out-of-order packet into the prefix.
    #[tokio::test]
    async fn test_asm_channel_full_asm_not_empty() {
        let (mut user_rx, _read) = user_rx(1);
        let msg = msg(0, b"a");

        user_rx.enqueue_test(UserRxMessage::Payload(msg.data.clone()));

        assert_eq!(
            user_rx.add_remove_test(msg.clone(), 1).await.unwrap(),
            AssemblerAddRemoveResult::Consumed {
                sequence_numbers: 0,
                bytes: 0
            }
        );

        assert_eq!(user_rx.ooq.stored_packets(), 1);
        assert_eq!(user_rx.ooq.stored_bytes(), 1);
        assert_eq!(user_rx.ooq.filled_front(), 0);

        assert_eq!(
            user_rx.add_remove_test(msg.clone(), 0).await.unwrap(),
            AssemblerAddRemoveResult::Consumed {
                sequence_numbers: 2,
                bytes: 2
            }
        );
        assert_eq!(user_rx.ooq.stored_packets(), 2);
        assert_eq!(user_rx.ooq.stored_bytes(), 2);
        assert_eq!(user_rx.ooq.filled_front(), 2);
    }

    // Packets arriving 1, 2, 0 are reassembled and read back in order.
    #[tokio::test]
    async fn test_asm_out_of_order() {
        setup_test_logging();

        let (mut user_rx, mut read) = user_rx(100);

        let msg_0 = msg(0, b"hello");
        let msg_1 = msg(1, b"world");
        let msg_2 = msg(2, b"test");

        assert_eq!(
            user_rx.add_remove_test(msg_1.clone(), 1).await.unwrap(),
            AssemblerAddRemoveResult::Consumed {
                sequence_numbers: 0,
                bytes: 0
            }
        );
        trace!(asm=%user_rx.ooq.debug_string(true));
        assert_eq!(user_rx.ooq.stored_packets(), 1);

        assert_eq!(
            user_rx.add_remove_test(msg_2.clone(), 2).await.unwrap(),
            AssemblerAddRemoveResult::Consumed {
                sequence_numbers: 0,
                bytes: 0
            }
        );
        trace!(asm=%user_rx.ooq.debug_string(true));

        assert_eq!(
            user_rx.add_remove_test(msg_0.clone(), 0).await.unwrap(),
            AssemblerAddRemoveResult::Consumed {
                sequence_numbers: 3,
                bytes: 14
            }
        );
        trace!(asm=%user_rx.ooq.debug_string(true));
        assert_eq!(user_rx.ooq.stored_packets(), 3);
        poll_fn(|cx| {
            assert_eq!(user_rx.flush(cx).unwrap(), 14);
            Poll::Ready(())
        })
        .await;
        assert_eq!(user_rx.ooq.stored_packets(), 0);

        let mut buf = [0u8; 1024];
        let sz = read.read(&mut buf).await.unwrap();
        assert_eq!(std::str::from_utf8(&buf[..sz]), Ok("helloworldtest"));
    }

    // In-order arrivals each extend the prefix by exactly one packet.
    #[tokio::test]
    async fn test_asm_inorder() {
        setup_test_logging();
        let (mut user_rx, mut read) = user_rx(100);

        let msg_0 = msg(0, b"hello");
        let msg_1 = msg(1, b"world");
        let msg_2 = msg(2, b"test");

        assert_eq!(
            user_rx.add_remove_test(msg_0.clone(), 0).await.unwrap(),
            AssemblerAddRemoveResult::Consumed {
                sequence_numbers: 1,
                bytes: 5
            }
        );
        trace!(asm=%user_rx.ooq.debug_string(true));
        assert_eq!(user_rx.ooq.stored_packets(), 1);

        assert_eq!(
            user_rx.add_remove_test(msg_1.clone(), 0).await.unwrap(),
            AssemblerAddRemoveResult::Consumed {
                sequence_numbers: 1,
                bytes: 5
            }
        );
        trace!(asm=%user_rx.ooq.debug_string(true));

        assert_eq!(
            user_rx.add_remove_test(msg_2.clone(), 0).await.unwrap(),
            AssemblerAddRemoveResult::Consumed {
                sequence_numbers: 1,
                bytes: 4
            }
        );
        trace!(asm=%user_rx.ooq.debug_string(true));
        assert_eq!(user_rx.ooq.stored_packets(), 3);

        poll_fn(|cx| {
            assert_eq!(user_rx.flush(cx).unwrap(), 14);
            Poll::Ready(())
        })
        .await;

        let mut buf = [0u8; 1024];
        let sz = read.read(&mut buf).await.unwrap();
        assert_eq!(std::str::from_utf8(&buf[..sz]), Ok("helloworldtest"));
    }

    // An offset past the queue's window is rejected as Unavailable.
    #[test]
    fn test_asm_write_out_of_bounds() {
        setup_test_logging();

        let mut asm = OutOfOrderQueue::new(NonZeroUsize::new(3).unwrap());

        let msg_2 = msg(2, b"test");
        let msg_3 = msg(3, b"test");

        assert_eq!(
            asm.add_remove(msg_2.clone(), 2).unwrap(),
            AssemblerAddRemoveResult::Consumed {
                sequence_numbers: 0,
                bytes: 0
            }
        );
        trace!(asm=%asm.debug_string(true));
        assert_eq!(asm.stored_packets(), 1);

        assert_eq!(
            asm.add_remove(msg_3.clone(), 3).unwrap(),
            AssemblerAddRemoveResult::Unavailable(msg_3)
        );
        trace!(asm=%asm.debug_string(true));
        assert_eq!(asm.stored_packets(), 1);
    }

    // Re-delivering a packet to an occupied slot reports AlreadyPresent.
    #[test]
    fn test_asm_duplicate_msg_ignored() {
        setup_test_logging();

        let mut asm = OutOfOrderQueue::new(NonZeroUsize::new(10).unwrap());
        let msg_2 = msg(2, b"test");
        assert_eq!(
            asm.add_remove(msg_2, 2).unwrap(),
            AssemblerAddRemoveResult::Consumed {
                sequence_numbers: 0,
                bytes: 0
            }
        );

        let msg_2 = msg(2, b"test");
        assert_eq!(
            asm.add_remove(msg_2, 2).unwrap(),
            AssemblerAddRemoveResult::AlreadyPresent
        );
    }
}