#define MS_CLASS "RTC::SCTP::OutstandingData"
#include "RTC/SCTP/tx/OutstandingData.hpp"
#include "Logger.hpp"
#include "MediaSoupErrors.hpp"
#include "RTC/SCTP/packet/chunks/AnyForwardTsnChunk.hpp"
#include "Utils.hpp"
#include <map>
namespace RTC
{
namespace SCTP
{
// Number of NACKs an item must accumulate before Item::Nack() considers it
// lost (and then either schedules it for retransmission or abandons it), unless
// the caller asks for immediate retransmission.
constexpr uint8_t NumberOfNacksForRetransmission{ 3 };
/**
 * Builds an item describing one chunk handed to the outstanding data tracker.
 *
 * Each argument is stored in the member of the same name; `data` is moved into
 * the item rather than copied.
 *
 * @param outgoingMessageId Id of the outgoing message this chunk belongs to.
 * @param data Chunk payload and metadata (moved in).
 * @param timeSentMs Time at which the chunk was sent, in milliseconds.
 * @param maxRetransmissions Retransmission limit checked by Nack().
 * @param expiresAtMs Expiration time, in milliseconds.
 * @param lifecycleId Optional lifecycle id associated with the chunk.
 */
OutstandingData::Item::Item(
uint32_t outgoingMessageId,
UserData data,
uint64_t timeSentMs,
uint16_t maxRetransmissions,
uint64_t expiresAtMs,
std::optional<uint64_t> lifecycleId)
: outgoingMessageId(outgoingMessageId),
data(std::move(data)),
timeSentMs(timeSentMs),
maxRetransmissions(maxRetransmissions),
expiresAtMs(expiresAtMs),
lifecycleId(lifecycleId)
{
MS_TRACE();
}
/**
 * Records that the peer acknowledged this item.
 *
 * The ack state always becomes ACKED. The lifecycle is reset to ACTIVE unless
 * the item was already abandoned, in which case it stays abandoned.
 */
void OutstandingData::Item::Ack()
{
	MS_TRACE();

	this->ackState = AckState::ACKED;

	// An abandoned item must never be resurrected by a late acknowledgement.
	if (this->lifecycle == Lifecycle::ABANDONED)
	{
		return;
	}

	this->lifecycle = Lifecycle::ACTIVE;
}
/**
 * Records one more NACK for this item and decides what should happen next.
 *
 * @param retransmitNow When true the item is treated as lost right away instead
 *   of waiting for NumberOfNacksForRetransmission NACKs.
 *
 * @return RETRANSMIT if the item has just been scheduled for retransmission,
 *   ABANDON if it has just been abandoned because its retransmission limit is
 *   exhausted, NOTHING otherwise.
 */
OutstandingData::Item::NackAction OutstandingData::Item::Nack(bool retransmitNow)
{
	MS_TRACE();

	this->ackState = AckState::NACKED;
	++this->nackCount;

	// Nothing new to decide for items already scheduled or already given up.
	if (ShouldBeRetransmitted() || IsAbandoned())
	{
		return NackAction::NOTHING;
	}

	const bool consideredLost =
	  retransmitNow || this->nackCount >= NumberOfNacksForRetransmission;

	if (!consideredLost)
	{
		return NackAction::NOTHING;
	}

	// Retransmission budget exhausted: give up on the item.
	if (this->numRetransmissions >= this->maxRetransmissions)
	{
		Abandon();

		return NackAction::ABANDON;
	}

	this->lifecycle = Lifecycle::TO_BE_RETRANSMITTED;

	return NackAction::RETRANSMIT;
}
/**
 * Resets the item after it has been handed out for retransmission: the
 * retransmission counter is bumped, the NACK counter is cleared and the item
 * goes back to the ACTIVE / UNACKED state.
 */
void OutstandingData::Item::MarkAsRetransmitted()
{
	MS_TRACE();

	++this->numRetransmissions;
	this->nackCount = 0;
	this->ackState  = AckState::UNACKED;
	this->lifecycle = Lifecycle::ACTIVE;
}
/**
 * Marks this item as abandoned.
 *
 * Asserts that the item has a bounded lifetime, i.e. that it does not have
 * both an infinite expiration time and an unlimited number of retransmissions.
 */
void OutstandingData::Item::Abandon()
{
MS_TRACE();
// At least one of the two limits (expiration time, retransmission count) must
// be finite for an item to be abandoned.
MS_ASSERT(
this->expiresAtMs != Types::ExpiresAtMsInfinite ||
this->maxRetransmissions != Types::MaxRetransmitsNoLimit,
"item should not have infinite expiration time or its retransmission times shouldn't be the maximum");
this->lifecycle = Lifecycle::ABANDONED;
}
/**
 * Creates an empty outstanding data tracker.
 *
 * @param dataChunkHeaderLength Length of the DATA chunk header, added to the
 *   payload length when computing the serialized length of a chunk.
 * @param lastCumulativeTsnAck Initial cumulative TSN ack point; outstanding
 *   items are numbered starting right after it.
 * @param discardFromSendQueue Callback invoked by AbandonAllFor() with the
 *   stream id and the outgoing message id of an abandoned message. When it
 *   returns true, an end-of-message placeholder item is appended.
 */
OutstandingData::OutstandingData(
size_t dataChunkHeaderLength,
Types::UnwrappedTsn lastCumulativeTsnAck,
std::function<bool(uint16_t , uint32_t )> discardFromSendQueue)
: dataChunkHeaderLength(dataChunkHeaderLength),
lastCumulativeTsnAck(lastCumulativeTsnAck),
discardFromSendQueue(std::move(discardFromSendQueue))
{
MS_TRACE();
}
/**
 * Processes a received SACK.
 *
 * Removes everything up to the cumulative TSN ack, acknowledges the items
 * covered by the gap ack blocks and NACKs the items in between those blocks.
 *
 * @param cumulativeTsnAck Cumulative TSN ack carried by the SACK.
 * @param gapAckBlocks Gap ack blocks carried by the SACK.
 * @param isInFastRecovery Whether the sender is currently in fast recovery.
 *
 * @return Summary of what this SACK acknowledged / reported as lost.
 */
OutstandingData::AckInfo OutstandingData::HandleSack(
  Types::UnwrappedTsn cumulativeTsnAck,
  std::span<const SackChunk::GapAckBlock> gapAckBlocks,
  bool isInFastRecovery)
{
	MS_TRACE();

	// Must be sampled before RemoveAcked(), which moves
	// this->lastCumulativeTsnAck forward.
	const bool ackPointAdvanced = cumulativeTsnAck > this->lastCumulativeTsnAck;

	AckInfo sackResult(cumulativeTsnAck);

	// 1. Drop everything cumulatively acknowledged.
	RemoveAcked(cumulativeTsnAck, sackResult);

	// 2. Acknowledge what the gap ack blocks report as received.
	AckGapBlocks(cumulativeTsnAck, gapAckBlocks, sackResult);

	// 3. NACK what lies between the gap ack blocks.
	NackBetweenAckBlocks(
	  cumulativeTsnAck, gapAckBlocks, isInFastRecovery, ackPointAdvanced, sackResult);

	AssertIsConsistent();

	return sackResult;
}
std::vector<std::pair<uint32_t , UserData>> OutstandingData::GetChunksToBeFastRetransmitted(
size_t maxLength)
{
MS_TRACE();
std::vector<std::pair<uint32_t , UserData>> result =
ExtractChunksThatCanFit(this->toBeFastRetransmitted, maxLength);
if (!this->toBeFastRetransmitted.empty())
{
this->toBeRetransmitted.insert(
this->toBeFastRetransmitted.begin(), this->toBeFastRetransmitted.end());
this->toBeFastRetransmitted.clear();
}
AssertIsConsistent();
return result;
}
/**
 * Extracts the chunks scheduled for regular retransmission that fit in
 * `maxLength` bytes.
 *
 * Asserts that the fast retransmission set is empty.
 *
 * @param maxLength Maximum number of serialized bytes to extract.
 *
 * @return Pairs of (wrapped TSN, cloned chunk data) to be sent.
 */
std::vector<std::pair<uint32_t, UserData>> OutstandingData::GetChunksToBeRetransmitted(
  size_t maxLength)
{
	MS_TRACE();

	MS_ASSERT(this->toBeFastRetransmitted.empty(), "this->toBeFastRetransmitted is not empty");

	auto chunks = ExtractChunksThatCanFit(this->toBeRetransmitted, maxLength);

	return chunks;
}
/**
 * Abandons the messages whose leading outstanding chunks are NACKed and have
 * expired at `nowMs`.
 *
 * The scan starts at the oldest outstanding item, skips items that are already
 * abandoned and stops at the first item that is neither abandoned nor
 * NACKed-and-expired.
 *
 * @param nowMs Current time, in milliseconds.
 */
void OutstandingData::ExpireOutstandingChunks(uint64_t nowMs)
{
	MS_TRACE();

	// First pass: only collect TSNs. AbandonAllFor() may append items to
	// this->outstandingData, so it is not called while iterating over it.
	std::vector<Types::UnwrappedTsn> expiredTsns;
	Types::UnwrappedTsn currentTsn = this->lastCumulativeTsnAck;

	for (const Item& candidate : this->outstandingData)
	{
		currentTsn.Increment();

		if (candidate.IsAbandoned())
		{
			continue;
		}

		if (!candidate.IsNacked() || !candidate.HasExpired(nowMs))
		{
			break;
		}

		expiredTsns.push_back(currentTsn);
	}

	// Second pass: abandon the whole message of every collected chunk.
	for (const Types::UnwrappedTsn tsnToExpire : expiredTsns)
	{
		const Item& item = GetItem(tsnToExpire);

		MS_WARN_TAG(
		  sctp,
		  "marking nacked Chunk %" PRIu32 " and message %" PRIu32 " as expired",
		  tsnToExpire.Wrap(),
		  item.GetData().GetMessageId());

		AbandonAllFor(item);
	}

	AssertIsConsistent();
}
/**
 * @return The TSN of the newest tracked item, computed as the last cumulative
 *   TSN ack plus the number of tracked items.
 */
Types::UnwrappedTsn OutstandingData::GetHighestOutstandingTsn() const
{
	MS_TRACE();

	const auto trackedItems = this->outstandingData.size();

	return Types::UnwrappedTsn::AddTo(this->lastCumulativeTsnAck, trackedItems);
}
std::optional<Types::UnwrappedTsn> OutstandingData::Insert(
uint32_t outgoingMessageId,
const UserData& data,
uint64_t timeSentMs,
uint16_t maxRetransmissions,
uint64_t expiresAtMs,
std::optional<uint64_t> lifecycleId)
{
MS_TRACE();
const size_t chunkLength = GetSerializedChunkLength(data);
this->unackedPayloadBytes += data.GetPayloadLength();
this->unackedPacketBytes += chunkLength;
++this->unackedItems;
const Types::UnwrappedTsn tsn = GetNextTsn();
const Item& item = this->outstandingData.emplace_back(
outgoingMessageId, data.Clone(), timeSentMs, maxRetransmissions, expiresAtMs, lifecycleId);
if (item.HasExpired(timeSentMs))
{
MS_WARN_TAG(
sctp,
"marking freshly produced Chunk %" PRIu32 " and message %" PRIu32 " as expired",
tsn.Wrap(),
item.GetData().GetMessageId());
AbandonAllFor(item);
AssertIsConsistent();
return std::nullopt;
}
AssertIsConsistent();
return tsn;
}
void OutstandingData::NackAll()
{
MS_TRACE();
Types::UnwrappedTsn tsn = this->lastCumulativeTsnAck;
std::vector<Types::UnwrappedTsn> tsnsToNack;
for (const Item& item : this->outstandingData)
{
tsn.Increment();
if (!item.IsAcked())
{
tsnsToNack.push_back(tsn);
}
}
for (const Types::UnwrappedTsn tsnToNack : tsnsToNack)
{
NackItem(
tsnToNack,
true,
false);
}
AssertIsConsistent();
}
/**
 * Builds a FORWARD-TSN chunk in the given packet.
 *
 * The new cumulative TSN covers the leading run of abandoned items, stopping
 * at the first item that is not abandoned or whose TSN is a stream reset
 * breakpoint. For every ordered stream seen in that run the highest stream
 * sequence number is reported as skipped.
 *
 * @param packet Packet in which the chunk is built.
 *
 * @return The chunk built in the packet.
 */
const ForwardTsnChunk* OutstandingData::AddForwardTsn(Packet* packet) const
{
	MS_TRACE();

	// Stream id => highest skipped stream sequence number.
	std::map<uint16_t, uint16_t> skippedPerOrderedStream;
	Types::UnwrappedTsn newCumulativeAck = this->lastCumulativeTsnAck;
	Types::UnwrappedTsn currentTsn       = this->lastCumulativeTsnAck;

	for (const Item& candidate : this->outstandingData)
	{
		currentTsn.Increment();

		if (this->streamResetBreakpointTsns.contains(currentTsn))
		{
			break;
		}

		if (currentTsn != newCumulativeAck.GetNextValue())
		{
			break;
		}

		if (!candidate.IsAbandoned())
		{
			break;
		}

		newCumulativeAck = currentTsn;

		// Only ordered streams are reported.
		if (candidate.GetData().IsUnordered())
		{
			continue;
		}

		auto& highestSsn = skippedPerOrderedStream[candidate.GetData().GetStreamId()];

		if (candidate.GetData().GetStreamSequenceNumber() > highestSsn)
		{
			highestSsn = candidate.GetData().GetStreamSequenceNumber();
		}
	}

	auto* forwardTsnChunk = packet->BuildChunkInPlace<ForwardTsnChunk>();

	forwardTsnChunk->SetNewCumulativeTsn(newCumulativeAck.Wrap());

	for (const auto& skipped : skippedPerOrderedStream)
	{
		const uint16_t streamId = skipped.first;
		const uint16_t ssn      = skipped.second;

		forwardTsnChunk->AddSkippedStream(AnyForwardTsnChunk::SkippedStream{ streamId, ssn });
	}

	forwardTsnChunk->Consolidate();

	return forwardTsnChunk;
}
/**
 * Builds an I-FORWARD-TSN chunk in the given packet.
 *
 * The new cumulative TSN covers the leading run of abandoned items, stopping
 * at the first item that is not abandoned or whose TSN is a stream reset
 * breakpoint. For every (unordered flag, stream id) pair seen in that run the
 * highest message id is reported as skipped.
 *
 * @param packet Packet in which the chunk is built.
 *
 * @return The chunk built in the packet.
 */
const IForwardTsnChunk* OutstandingData::AddIForwardTsn(Packet* packet) const
{
	MS_TRACE();

	// (unordered, stream id) => highest skipped message id.
	std::map<std::pair<bool, uint16_t>, uint32_t> skippedPerStream;
	Types::UnwrappedTsn newCumulativeAck = this->lastCumulativeTsnAck;
	Types::UnwrappedTsn currentTsn       = this->lastCumulativeTsnAck;

	for (const Item& candidate : this->outstandingData)
	{
		currentTsn.Increment();

		if (this->streamResetBreakpointTsns.contains(currentTsn))
		{
			break;
		}

		if (currentTsn != newCumulativeAck.GetNextValue())
		{
			break;
		}

		if (!candidate.IsAbandoned())
		{
			break;
		}

		newCumulativeAck = currentTsn;

		const std::pair<bool, uint16_t> streamKey{ candidate.GetData().IsUnordered(),
			                                         candidate.GetData().GetStreamId() };
		uint32_t& highestMid = skippedPerStream[streamKey];

		if (candidate.GetData().GetMessageId() > highestMid)
		{
			highestMid = candidate.GetData().GetMessageId();
		}
	}

	auto* iForwardTsnChunk = packet->BuildChunkInPlace<IForwardTsnChunk>();

	iForwardTsnChunk->SetNewCumulativeTsn(newCumulativeAck.Wrap());

	for (const auto& skipped : skippedPerStream)
	{
		const bool unordered    = skipped.first.first;
		const uint16_t streamId = skipped.first.second;
		const uint32_t mid      = skipped.second;

		iForwardTsnChunk->AddSkippedStream(
		  AnyForwardTsnChunk::SkippedStream{ unordered, streamId, mid });
	}

	iForwardTsnChunk->Consolidate();

	return iForwardTsnChunk;
}
std::optional<uint64_t> OutstandingData::MeasureRtt(uint64_t nowMs, Types::UnwrappedTsn tsn) const
{
MS_TRACE();
if (tsn > this->lastCumulativeTsnAck && tsn < GetNextTsn())
{
const Item& item = GetItem(tsn);
if (!item.HasBeenRetransmitted())
{
return nowMs - item.GetTimeSentMs();
}
}
return std::nullopt;
}
/**
 * @return true if there is at least one tracked item and the oldest one is
 *   abandoned, false otherwise.
 */
bool OutstandingData::ShouldSendForwardTsn() const
{
	MS_TRACE();

	return !this->outstandingData.empty() && this->outstandingData.front().IsAbandoned();
}
void OutstandingData::BeginResetStreams()
{
MS_TRACE();
this->streamResetBreakpointTsns.insert(GetNextTsn());
}
#ifdef MS_TEST
/**
 * Test-only helper returning the state of every tracked chunk.
 *
 * The first entry is the last cumulative TSN ack itself, reported as ACKED,
 * followed by one entry per tracked item in TSN order.
 *
 * @return Pairs of (wrapped TSN, state).
 *
 * @throws If an item is in none of the known states.
 */
std::vector<std::pair<uint32_t, OutstandingData::State>> OutstandingData::GetChunkStatesForTesting()
  const
{
	MS_TRACE();

	std::vector<std::pair<uint32_t, State>> chunkStates;

	chunkStates.emplace_back(this->lastCumulativeTsnAck.Wrap(), State::ACKED);

	Types::UnwrappedTsn currentTsn = this->lastCumulativeTsnAck;

	for (const Item& item : this->outstandingData)
	{
		currentTsn.Increment();

		State itemState;

		// The order of these checks defines which state wins.
		if (item.IsAbandoned())
		{
			itemState = State::ABANDONED;
		}
		else if (item.ShouldBeRetransmitted())
		{
			itemState = State::TO_BE_RETRANSMITTED;
		}
		else if (item.IsAcked())
		{
			itemState = State::ACKED;
		}
		else if (item.IsNacked())
		{
			itemState = State::NACKED;
		}
		else if (!item.IsOutstanding())
		{
			MS_THROW_ERROR("should not end here");
		}
		else
		{
			itemState = State::IN_FLIGHT;
		}

		chunkStates.emplace_back(currentTsn.Wrap(), itemState);
	}

	return chunkStates;
}
#endif
/**
 * @param data Chunk data.
 *
 * @return The DATA chunk header length plus the payload length of `data`,
 *   padded to a multiple of 4 bytes.
 */
size_t OutstandingData::GetSerializedChunkLength(const UserData& data) const
{
	MS_TRACE();

	const size_t unpaddedLength = this->dataChunkHeaderLength + data.GetPayloadLength();

	return Utils::Byte::PadTo4Bytes<size_t>(unpaddedLength);
}
/**
 * Returns the tracked item with the given TSN.
 *
 * Asserts that `tsn` is strictly above the last cumulative TSN ack and
 * strictly below the next TSN to be assigned.
 *
 * @param tsn TSN of the item to look up.
 *
 * @return Mutable reference to the item.
 */
OutstandingData::Item& OutstandingData::GetItem(Types::UnwrappedTsn tsn)
{
	MS_TRACE();

	MS_ASSERT(
	  tsn > this->lastCumulativeTsnAck, "tsn must be higher than this->lastCumulativeTsnAck");
	MS_ASSERT(tsn < GetNextTsn(), "tsn must be lower than GetNextTsn()");

	// The item with TSN `lastCumulativeTsnAck + 1` is stored at index 0.
	// `index` is unsigned, so a lower bound check against 0 would always pass;
	// any wrap-around is caught by the upper bound check below.
	const size_t index = Types::UnwrappedTsn::Difference(tsn, this->lastCumulativeTsnAck) - 1;

	MS_ASSERT(
	  index < this->outstandingData.size(),
	  "index must be lower than this->outstandingData.size()");

	return this->outstandingData[index];
}
/**
 * Returns the tracked item with the given TSN (read-only overload).
 *
 * Asserts that `tsn` is strictly above the last cumulative TSN ack and
 * strictly below the next TSN to be assigned.
 *
 * @param tsn TSN of the item to look up.
 *
 * @return Const reference to the item.
 */
const OutstandingData::Item& OutstandingData::GetItem(Types::UnwrappedTsn tsn) const
{
	MS_TRACE();

	MS_ASSERT(
	  tsn > this->lastCumulativeTsnAck, "tsn must be higher than this->lastCumulativeTsnAck");
	MS_ASSERT(tsn < GetNextTsn(), "tsn must be lower than GetNextTsn()");

	// The item with TSN `lastCumulativeTsnAck + 1` is stored at index 0.
	// `index` is unsigned, so a lower bound check against 0 would always pass;
	// any wrap-around is caught by the upper bound check below.
	const size_t index = Types::UnwrappedTsn::Difference(tsn, this->lastCumulativeTsnAck) - 1;

	MS_ASSERT(
	  index < this->outstandingData.size(),
	  "index must be lower than this->outstandingData.size()");

	return this->outstandingData[index];
}
/**
 * Acknowledges and removes every tracked item up to `cumulativeTsnAck`,
 * moving the last cumulative TSN ack forward accordingly.
 *
 * Lifecycle ids of removed items are reported in `ackInfo` as abandoned or
 * acked depending on the item state. Stream reset breakpoints up to and
 * including `cumulativeTsnAck + 1` are dropped.
 *
 * @param cumulativeTsnAck Cumulative TSN ack received from the peer.
 * @param ackInfo Summary being filled for the SACK under processing.
 */
void OutstandingData::RemoveAcked(Types::UnwrappedTsn cumulativeTsnAck, AckInfo& ackInfo)
{
	MS_TRACE();

	while (!this->outstandingData.empty() && this->lastCumulativeTsnAck < cumulativeTsnAck)
	{
		Item& item = this->outstandingData.front();

		AckChunk(ackInfo, this->lastCumulativeTsnAck.GetNextValue(), item);

		const auto& lifecycleId = item.GetLifecycleId();

		if (lifecycleId.has_value())
		{
			MS_ASSERT(item.GetData().IsEnd(), "item.GetData().IsEnd() must be true");

			if (!item.IsAbandoned())
			{
				ackInfo.ackedLifecycleIds.push_back(lifecycleId.value());
			}
			else
			{
				ackInfo.abandonedLifecycleIds.push_back(lifecycleId.value());
			}
		}

		// `item` and `lifecycleId` must not be used past this point.
		this->outstandingData.pop_front();
		this->lastCumulativeTsnAck.Increment();
	}

	const auto firstBreakpointToKeep =
	  this->streamResetBreakpointTsns.upper_bound(cumulativeTsnAck.GetNextValue());

	this->streamResetBreakpointTsns.erase(
	  this->streamResetBreakpointTsns.begin(), firstBreakpointToKeep);
}
/**
 * Acknowledges every tracked item covered by the given gap ack blocks.
 *
 * Block boundaries are offsets relative to `cumulativeTsnAck`. TSNs that are
 * not currently tracked are ignored. Acknowledged items are not removed.
 *
 * @param cumulativeTsnAck Cumulative TSN ack received from the peer.
 * @param gapAckBlocks Gap ack blocks received from the peer.
 * @param ackInfo Summary being filled for the SACK under processing.
 */
void OutstandingData::AckGapBlocks(
  Types::UnwrappedTsn cumulativeTsnAck,
  std::span<const SackChunk::GapAckBlock> gapAckBlocks,
  AckInfo& ackInfo)
{
	MS_TRACE();

	for (const auto& gapBlock : gapAckBlocks)
	{
		const Types::UnwrappedTsn firstTsn =
		  Types::UnwrappedTsn::AddTo(cumulativeTsnAck, gapBlock.start);
		const Types::UnwrappedTsn lastTsn = Types::UnwrappedTsn::AddTo(cumulativeTsnAck, gapBlock.end);

		Types::UnwrappedTsn currentTsn = firstTsn;

		while (currentTsn <= lastTsn)
		{
			const bool isTracked =
			  currentTsn > this->lastCumulativeTsnAck && currentTsn < GetNextTsn();

			if (isTracked)
			{
				AckChunk(ackInfo, currentTsn, GetItem(currentTsn));
			}

			currentTsn = currentTsn.GetNextValue();
		}
	}
}
/**
 * NACKs the tracked items located in the holes before and between the given
 * gap ack blocks.
 *
 * Only TSNs up to a limit are NACKed. The limit is the highest TSN newly
 * acknowledged by this SACK, or the end of the last gap ack block when in fast
 * recovery and the cumulative TSN ack advanced. Items above the last gap ack
 * block are never NACKed here.
 *
 * NOTE(review): TSNs are only bounded from above (`< GetNextTsn()`), so this
 * appears to rely on `cumulativeTsnAck` not being lower than
 * this->lastCumulativeTsnAck. Confirm that the caller validates it.
 *
 * @param cumulativeTsnAck Cumulative TSN ack received from the peer.
 * @param gapAckBlocks Gap ack blocks received from the peer.
 * @param isInFastRecovery Whether the sender is in fast recovery. When it is
 *   not, items to retransmit are scheduled for fast retransmission.
 * @param cumulativeTsnAckedAdvanced Whether this SACK advanced the cumulative
 *   TSN ack.
 * @param ackInfo Summary being filled; `hasPacketLoss` is set when a NACK
 *   results in a retransmission or an abandonment.
 */
void OutstandingData::NackBetweenAckBlocks(
  Types::UnwrappedTsn cumulativeTsnAck,
  std::span<const SackChunk::GapAckBlock> gapAckBlocks,
  bool isInFastRecovery,
  bool cumulativeTsnAckedAdvanced,
  AckInfo& ackInfo)
{
	MS_TRACE();

	Types::UnwrappedTsn nackLimit = ackInfo.highestTsnAcked;

	if (isInFastRecovery && cumulativeTsnAckedAdvanced)
	{
		nackLimit = Types::UnwrappedTsn::AddTo(
		  cumulativeTsnAck, gapAckBlocks.empty() ? 0 : gapAckBlocks.rbegin()->end);
	}

	const bool doFastRetransmit = !isInFastRecovery;
	Types::UnwrappedTsn previousBlockEnd = cumulativeTsnAck;

	for (const auto& gapBlock : gapAckBlocks)
	{
		const Types::UnwrappedTsn blockStart =
		  Types::UnwrappedTsn::AddTo(cumulativeTsnAck, gapBlock.start);

		Types::UnwrappedTsn missingTsn = previousBlockEnd.GetNextValue();

		// Walk the hole between the previous block and this one.
		while (missingTsn < blockStart && missingTsn <= nackLimit && missingTsn < GetNextTsn())
		{
			ackInfo.hasPacketLoss |= NackItem(missingTsn, /*retransmitNow*/ false, doFastRetransmit);

			missingTsn = missingTsn.GetNextValue();
		}

		previousBlockEnd = Types::UnwrappedTsn::AddTo(cumulativeTsnAck, gapBlock.end);
	}
}
/**
 * Acknowledges a single item, updating the accounting and `ackInfo`.
 *
 * Does nothing if the item is already acknowledged.
 *
 * @param ackInfo Summary being filled; `bytesAcked` and `highestTsnAcked` are
 *   updated.
 * @param tsn TSN of the item.
 * @param item The item to acknowledge.
 */
void OutstandingData::AckChunk(AckInfo& ackInfo, Types::UnwrappedTsn tsn, Item& item)
{
	MS_TRACE();

	if (item.IsAcked())
	{
		return;
	}

	const size_t chunkLength = GetSerializedChunkLength(item.GetData());

	ackInfo.bytesAcked += chunkLength;

	// The item state must be inspected before item.Ack() changes it.
	if (item.IsOutstanding())
	{
		--this->unackedItems;
		this->unackedPacketBytes -= chunkLength;
		this->unackedPayloadBytes -= item.GetData().GetPayloadLength();
	}

	if (item.ShouldBeRetransmitted())
	{
		MS_ASSERT(
		  !this->toBeFastRetransmitted.contains(tsn),
		  "tsn should not be present in this->toBeFastRetransmitted");

		this->toBeRetransmitted.erase(tsn);
	}

	item.Ack();

	if (ackInfo.highestTsnAcked < tsn)
	{
		ackInfo.highestTsnAcked = tsn;
	}
}
/**
 * NACKs the item with the given TSN and applies the resulting action.
 *
 * @param tsn TSN of the item to NACK.
 * @param retransmitNow Forwarded to Item::Nack().
 * @param doFastRetransmit When the item must be retransmitted, selects the
 *   fast retransmission set instead of the regular one.
 *
 * @return true if the NACK resulted in the item being scheduled for
 *   retransmission or abandoned, false otherwise (including when the item was
 *   already acknowledged).
 */
bool OutstandingData::NackItem(Types::UnwrappedTsn tsn, bool retransmitNow, bool doFastRetransmit)
{
	MS_TRACE();

	Item& target = GetItem(tsn);

	if (target.IsAcked())
	{
		return false;
	}

	const bool outstandingBefore  = target.IsOutstanding();
	const Item::NackAction result = target.Nack(retransmitNow);

	// The NACK took the item out of flight: stop accounting for it.
	if (outstandingBefore && !target.IsOutstanding())
	{
		--this->unackedItems;
		this->unackedPacketBytes -= GetSerializedChunkLength(target.GetData());
		this->unackedPayloadBytes -= target.GetData().GetPayloadLength();
	}

	if (result == Item::NackAction::NOTHING)
	{
		return false;
	}

	if (result == Item::NackAction::RETRANSMIT)
	{
		auto& retransmissionSet =
		  doFastRetransmit ? this->toBeFastRetransmitted : this->toBeRetransmitted;

		retransmissionSet.insert(tsn);

		MS_DEBUG_TAG(sctp, "tsn %" PRIu32 " marked for retransmission", tsn.Wrap());
	}
	else if (result == Item::NackAction::ABANDON)
	{
		MS_DEBUG_TAG(sctp, "tsn %" PRIu32 " nacked, resulted in abandoning", tsn.Wrap());

		AbandonAllFor(target);
	}

	return true;
}
/**
 * Abandons every tracked chunk belonging to the same message (same stream id
 * and outgoing message id) as `item`.
 *
 * The send queue is first asked to discard the message. If it reports that it
 * did discard something, an empty end-of-message placeholder item is appended
 * under the next TSN and immediately marked as acked; being part of the same
 * message, it is then abandoned too.
 *
 * @param item Any item of the message to abandon.
 */
void OutstandingData::AbandonAllFor(const OutstandingData::Item& item)
{
	MS_TRACE();

	const auto streamId          = item.GetData().GetStreamId();
	const auto outgoingMessageId = item.GetOutgoingMessageId();

	if (this->discardFromSendQueue(streamId, outgoingMessageId))
	{
		UserData placeholderData(
		  item.GetData().GetStreamId(),
		  item.GetData().GetStreamSequenceNumber(),
		  item.GetData().GetMessageId(),
		  item.GetData().GetFragmentSequenceNumber(),
		  item.GetData().GetPayloadProtocolId(),
		  std::vector<uint8_t>(),
		  false,
		  true,
		  item.GetData().IsUnordered());

		// Must be read before the placeholder is appended.
		const Types::UnwrappedTsn placeholderTsn = GetNextTsn();

		Item& placeholder = this->outstandingData.emplace_back(
		  item.GetOutgoingMessageId(),
		  std::move(placeholderData),
		  /*timeSentMs*/ 0,
		  /*maxRetransmissions*/ 0,
		  Types::ExpiresAtMsInfinite,
		  std::nullopt);

		placeholder.Ack();

		MS_DEBUG_TAG(
		  sctp, "adding unsent end placeholder for message at TSN %" PRIu32, placeholderTsn.Wrap());
	}

	Types::UnwrappedTsn currentTsn = this->lastCumulativeTsnAck;

	for (Item& candidate : this->outstandingData)
	{
		currentTsn.Increment();

		const bool sameMessage = candidate.GetData().GetStreamId() == streamId &&
		                         candidate.GetOutgoingMessageId() == outgoingMessageId;

		if (candidate.IsAbandoned() || !sameMessage)
		{
			continue;
		}

		MS_WARN_TAG(sctp, "marking Chunk %" PRIu32 " as abandoned", currentTsn.Wrap());

		if (candidate.ShouldBeRetransmitted())
		{
			this->toBeFastRetransmitted.erase(currentTsn);
			this->toBeRetransmitted.erase(currentTsn);
		}

		// Whether it was in flight must be sampled before abandoning it.
		const bool wasInFlight = candidate.IsOutstanding();

		candidate.Abandon();

		if (wasInFlight)
		{
			--this->unackedItems;
			this->unackedPacketBytes -= GetSerializedChunkLength(candidate.GetData());
			this->unackedPayloadBytes -= candidate.GetData().GetPayloadLength();
		}
	}
}
/**
 * Extracts, in TSN order, the chunks of `chunks` that fit in `maxLength`
 * bytes.
 *
 * Every extracted chunk is marked as retransmitted, accounted as in flight
 * again and removed from `chunks`. Chunks that do not fit are left in place.
 * The scan stops once the remaining space is not larger than a DATA chunk
 * header.
 *
 * @param chunks Set of TSNs scheduled for retransmission (modified).
 * @param maxLength Maximum number of serialized bytes to extract.
 *
 * @return Pairs of (wrapped TSN, cloned chunk data) to be sent.
 */
std::vector<std::pair<uint32_t, UserData>> OutstandingData::ExtractChunksThatCanFit(
  std::set<Types::UnwrappedTsn>& chunks, size_t maxLength)
{
	MS_TRACE();

	std::vector<std::pair<uint32_t, UserData>> extracted;
	size_t remainingLength = maxLength;
	auto cursor            = chunks.begin();

	while (cursor != chunks.end())
	{
		const Types::UnwrappedTsn tsn = *cursor;
		Item& item                    = GetItem(tsn);

		MS_ASSERT(item.ShouldBeRetransmitted(), "item should be retransmitted");
		MS_ASSERT(!item.IsOutstanding(), "item should not be outstanding");
		MS_ASSERT(!item.IsAbandoned(), "item should not be abandoned");
		MS_ASSERT(!item.IsAcked(), "item should not be acked");

		const size_t chunkLength = GetSerializedChunkLength(item.GetData());

		if (chunkLength > remainingLength)
		{
			// Too big for the remaining space; a later chunk may still fit.
			++cursor;
		}
		else
		{
			item.MarkAsRetransmitted();
			extracted.emplace_back(tsn.Wrap(), item.GetData().Clone());
			remainingLength -= chunkLength;

			this->unackedPayloadBytes += item.GetData().GetPayloadLength();
			this->unackedPacketBytes += chunkLength;
			++this->unackedItems;

			cursor = chunks.erase(cursor);
		}

		// Not even room for an empty chunk: stop scanning.
		if (remainingLength <= this->dataChunkHeaderLength)
		{
			break;
		}
	}

	return extracted;
}
void OutstandingData::AssertIsConsistent() const
{
MS_TRACE();
size_t actualUnackedPayloadBytes{ 0 };
size_t actualUnackedPacketBytes{ 0 };
size_t actualUnackedItems{ 0 };
std::set<Types::UnwrappedTsn> combinedToBeRetransmitted;
combinedToBeRetransmitted.insert(this->toBeRetransmitted.begin(), this->toBeRetransmitted.end());
combinedToBeRetransmitted.insert(
this->toBeFastRetransmitted.begin(), this->toBeFastRetransmitted.end());
std::set<Types::UnwrappedTsn> actualCombinedToBeRetransmitted;
Types::UnwrappedTsn tsn = this->lastCumulativeTsnAck;
for (const Item& item : this->outstandingData)
{
tsn.Increment();
if (item.IsOutstanding())
{
actualUnackedPayloadBytes += item.GetData().GetPayloadLength();
actualUnackedPacketBytes += GetSerializedChunkLength(item.GetData());
++actualUnackedItems;
}
if (item.ShouldBeRetransmitted())
{
actualCombinedToBeRetransmitted.insert(tsn);
}
}
MS_ASSERT(
actualUnackedPayloadBytes == this->unackedPayloadBytes,
"actualUnackedPayloadBytes != this->unackedPayloadBytes");
MS_ASSERT(
actualUnackedPacketBytes == this->unackedPacketBytes,
"actualUnackedPacketBytes != this->unackedPacketBytes");
MS_ASSERT(actualUnackedItems == this->unackedItems, "actualUnackedItems != this->unackedItems");
MS_ASSERT(
actualCombinedToBeRetransmitted == combinedToBeRetransmitted,
"actualCombinedToBeRetransmitted != combinedToBeRetransmitted");
}
} }