use super::{
ContinuationFrameHeader,
FirstFrameHeader,
FrameSequence,
MessageKey,
error::{MessageSeriesError, MessageSeriesStatus},
};
/// Whether a [`MessageSeries`] is checking continuation sequence numbers.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
enum SequenceTracking {
    /// Initial state: no sequence number has been recorded, so continuation
    /// frames may omit one.
    Untracked,
    /// Sequence numbers are being tracked: every continuation must carry one,
    /// and it must equal the expected next value.
    Tracked,
}
/// Tracks the frames of a single message: its key, the expected ordering of
/// continuation frames, and whether the final frame has arrived.
#[derive(Debug, Clone)]
pub struct MessageSeries {
    /// Key taken from the first frame; continuation frames must carry the same key.
    message_key: MessageKey,
    /// Sequence number the next continuation must carry while tracking is on;
    /// `None` before any sequence has been recorded.
    next_sequence: Option<FrameSequence>,
    /// Whether continuation sequence numbers are currently being checked.
    sequence_tracking: SequenceTracking,
    /// Set once a frame flagged as last has been accepted.
    complete: bool,
    /// `total_body_len` copied from the first frame header, if it was present.
    expected_total: Option<usize>,
}
impl MessageSeries {
    /// Creates a series from the header of its first frame.
    ///
    /// The series takes the header's message key and `total_body_len`. It is
    /// complete immediately when the first frame is already flagged as last.
    /// Sequence tracking starts disabled. The first continuation frame that
    /// carries a sequence number enables it.
    #[must_use]
    pub fn from_first_frame(header: &FirstFrameHeader) -> Self {
        Self {
            message_key: header.message_key,
            next_sequence: None,
            sequence_tracking: SequenceTracking::Untracked,
            expected_total: header.total_body_len,
            complete: header.is_last,
        }
    }

    /// Returns the message key every frame of this series must carry.
    #[must_use]
    pub const fn message_key(&self) -> MessageKey { self.message_key }

    /// Returns `true` once a frame flagged as last has been accepted
    /// (including a first frame that was already last).
    #[must_use]
    pub const fn is_complete(&self) -> bool { self.complete }

    /// Returns the `total_body_len` taken from the first frame header, if any.
    #[must_use]
    pub const fn expected_total(&self) -> Option<usize> { self.expected_total }

    /// Validates a continuation frame and records it as part of the series.
    ///
    /// Returns [`MessageSeriesStatus::Complete`] when the frame is flagged as
    /// last, otherwise [`MessageSeriesStatus::Incomplete`]. A rejected frame
    /// leaves the series unchanged, so the caller may retry with a corrected
    /// frame.
    ///
    /// # Errors
    ///
    /// - [`MessageSeriesError::KeyMismatch`] if the frame's key differs from
    ///   the series key.
    /// - [`MessageSeriesError::SeriesComplete`] if a last frame was already
    ///   accepted.
    /// - [`MessageSeriesError::MissingSequence`] if sequence tracking is active
    ///   but the frame carries no sequence number.
    /// - [`MessageSeriesError::DuplicateFrame`] if the sequence number is lower
    ///   than the expected one.
    /// - [`MessageSeriesError::SequenceMismatch`] if the sequence number skips
    ///   ahead of the expected one.
    /// - [`MessageSeriesError::SequenceOverflow`] if the sequence number cannot
    ///   be incremented (`checked_increment` returns `None`) and the frame is
    ///   not the last.
    pub fn accept_continuation(
        &mut self,
        header: &ContinuationFrameHeader,
    ) -> Result<MessageSeriesStatus, MessageSeriesError> {
        if header.message_key != self.message_key {
            return Err(MessageSeriesError::KeyMismatch {
                expected: self.message_key,
                found: header.message_key,
            });
        }
        if self.complete {
            return Err(MessageSeriesError::SeriesComplete);
        }
        if let Some(incoming_seq) = header.sequence {
            self.validate_and_advance_sequence(incoming_seq, header.is_last)?;
        } else if self.sequence_tracking == SequenceTracking::Tracked {
            // Once sequencing has started, every later frame must take part in it.
            return Err(MessageSeriesError::MissingSequence {
                key: self.message_key,
            });
        }
        if header.is_last {
            self.complete = true;
            Ok(MessageSeriesStatus::Complete)
        } else {
            Ok(MessageSeriesStatus::Incomplete)
        }
    }

    /// Validates `incoming` and, only if it is acceptable, advances the
    /// expected sequence past it.
    ///
    /// The first sequenced frame of an untracked series is accepted as-is and
    /// becomes the baseline. No field is written until every check has passed,
    /// so an error leaves the tracking state exactly as it was.
    fn validate_and_advance_sequence(
        &mut self,
        incoming: FrameSequence,
        is_last: bool,
    ) -> Result<(), MessageSeriesError> {
        if self.sequence_tracking == SequenceTracking::Tracked {
            self.check_expected_sequence(incoming)?;
        }
        // A non-final frame at the highest sequence value would leave no valid
        // sequence number for the frame after it.
        let next = incoming.checked_increment();
        if next.is_none() && !is_last {
            return Err(MessageSeriesError::SequenceOverflow { last: incoming });
        }
        // Commit only now that the frame is known to be acceptable.
        self.sequence_tracking = SequenceTracking::Tracked;
        self.next_sequence = next;
        Ok(())
    }

    /// Checks `incoming` against the sequence number a tracked series expects
    /// next, without modifying the series.
    fn check_expected_sequence(&self, incoming: FrameSequence) -> Result<(), MessageSeriesError> {
        // `None` means the previously accepted frame already used the highest
        // sequence value, so no further sequence number can be valid.
        let expected = self
            .next_sequence
            .ok_or(MessageSeriesError::SequenceOverflow { last: incoming })?;
        if incoming.0 < expected.0 {
            return Err(MessageSeriesError::DuplicateFrame {
                key: self.message_key,
                sequence: incoming,
            });
        }
        if incoming != expected {
            return Err(MessageSeriesError::SequenceMismatch {
                expected,
                found: incoming,
            });
        }
        Ok(())
    }
}
impl MessageSeries {
    /// Turns sequence tracking on and sets the next expected sequence to `next`.
    ///
    /// This is hidden from rustdoc and intended for tests only. It skips the
    /// validation performed by `accept_continuation`, so a test can place the
    /// series at an arbitrary sequence position directly.
    #[doc(hidden)]
    pub fn force_next_sequence_for_tests(&mut self, next: FrameSequence) {
        let forced = Some(next);
        self.next_sequence = forced;
        self.sequence_tracking = SequenceTracking::Tracked;
    }
}