#define MS_CLASS "RTC::SCTP::InterleavedReassemblyStreams"
#include "RTC/SCTP/rx/InterleavedReassemblyStreams.hpp"
#include "Logger.hpp"
namespace RTC
{
namespace SCTP
{
// Builds the reassembly-streams container.
//
// `onAssembledMessage` is taken by value and moved into the member of the
// same name; the Stream code in this file invokes it (through its `parent`
// reference) each time a complete message has been put together.
InterleavedReassemblyStreams::InterleavedReassemblyStreams(
ReassemblyStreamsInterface::OnAssembledMessage onAssembledMessage)
// The parameter shadows the member: the left-hand name is the member, the
// argument of std::move() is the parameter.
: onAssembledMessage(std::move(onAssembledMessage))
{
MS_TRACE();
}
/**
 * Routes a received data chunk to the stream it belongs to.
 *
 * The target stream is selected by the chunk's unordered flag together with
 * its stream id, and is created on first use.
 *
 * @param tsn  Unwrapped TSN the chunk was received with.
 * @param data The chunk; ownership is handed over to the stream.
 * @return Whatever Stream::AddData() reports for this chunk.
 */
int32_t InterleavedReassemblyStreams::AddData(Types::UnwrappedTsn tsn, UserData data)
{
	MS_TRACE();

	// Build the key while `data` is still intact; it is moved from right after.
	const FullStreamId targetId(data.IsUnordered(), data.GetStreamId());
	Stream& targetStream = GetOrCreateStream(targetId);

	return targetStream.AddData(tsn, std::move(data));
}
/**
 * Applies a forward-TSN request to every stream it lists.
 *
 * For each skipped-stream entry, the matching stream (created if it does not
 * exist yet) is asked to erase everything up to and including the entry's
 * `mid`.
 *
 * @param (unnamed)      TSN argument required by the interface; not used here.
 * @param skippedStreams Entries naming a stream (unordered flag + stream id)
 *                       and the `mid` to erase up to.
 * @return Sum of the values returned by Stream::EraseTo() for all entries.
 */
size_t InterleavedReassemblyStreams::HandleForwardTsn(
  Types::UnwrappedTsn,
  std::span<const AnyForwardTsnChunk::SkippedStream> skippedStreams)
{
	MS_TRACE();

	size_t totalRemoved = 0;

	for (const auto& skipped : skippedStreams)
	{
		const FullStreamId targetId(skipped.unordered, skipped.streamId);
		Stream& targetStream = GetOrCreateStream(targetId);

		totalRemoved += targetStream.EraseTo(skipped.mid);
	}

	return totalRemoved;
}
/**
 * Resets reassembly streams.
 *
 * @param streamIds Stream ids to reset. When non-empty, both the unordered
 *                  and the ordered stream of every listed id are reset (and
 *                  created first if they do not exist). When empty, every
 *                  stream currently stored in `this->streams` is reset.
 */
void InterleavedReassemblyStreams::ResetStreams(std::span<const uint16_t> streamIds)
{
	MS_TRACE();

	if (!streamIds.empty())
	{
		for (const auto streamId : streamIds)
		{
			MS_DEBUG_DEV("resetting explicit stream [streamId:%" PRIu16 "]", streamId);

			// Unordered first, then ordered.
			const FullStreamId unorderedId(true, streamId);
			GetOrCreateStream(unorderedId).Reset();

			const FullStreamId orderedId(false, streamId);
			GetOrCreateStream(orderedId).Reset();
		}

		return;
	}

	// No explicit list given: reset everything we know about.
	for (auto& entry : this->streams)
	{
		MS_DEBUG_DEV("resetting implicit stream [streamId:%" PRIu16 "]", entry.first.streamId);

		entry.second.Reset();
	}
}
/**
 * Returns the stream stored under `streamId`, creating and storing a new one
 * first when no such entry exists.
 *
 * @param streamId Key of the wanted stream (unordered flag plus stream id).
 * @return Reference to the Stream held in `this->streams`.
 */
InterleavedReassemblyStreams::Stream& InterleavedReassemblyStreams::GetOrCreateStream(
  const FullStreamId& streamId)
{
	MS_TRACE();

	// try_emplace() searches the container once. Only when the key is absent
	// does it construct the mapped Stream in place from (streamId, this);
	// an existing entry is left untouched and simply returned.
	auto it = this->streams.try_emplace(streamId, streamId, this).first;

	return it->second;
}
/**
 * Adds a received chunk to this stream and delivers whatever becomes
 * deliverable as a result.
 *
 * @param tsn  Unwrapped TSN the chunk was received with.
 * @param data The chunk. Its unordered flag and stream id must match this
 *             stream (asserted).
 * @return Change in the number of payload bytes held by this stream:
 *         - 0 when an unordered chunk that is both beginning and end is
 *           delivered immediately, or when a chunk with the same message id
 *           and fragment sequence number is already stored;
 *         - a value <= 0 when an ordered complete chunk for `nextMid` is
 *           delivered immediately (minus the bytes of queued messages that
 *           this unblocked);
 *         - otherwise the chunk's payload length minus the bytes of any
 *           messages that could be assembled after storing it.
 */
int32_t InterleavedReassemblyStreams::Stream::AddData(Types::UnwrappedTsn tsn, UserData data)
{
	MS_TRACE();

	MS_ASSERT(
	  data.IsUnordered() == this->fullStreamId.unordered,
	  "data.IsUnordered() != this->fullStreamId.unordered");
	MS_ASSERT(
	  data.GetStreamId() == this->fullStreamId.streamId,
	  "data.GetStreamId() != this->fullStreamId.streamId");

	// Everything needed from `data` is read up front because it is moved from
	// in every path below.
	auto queuedBytes              = static_cast<int32_t>(data.GetPayloadLength());
	const Types::UnwrappedMid mid = this->midUnwrapper.Unwrap(data.GetMessageId());
	const uint32_t fsn            = data.GetFragmentSequenceNumber();

	// Fast paths: a chunk that is a whole message by itself is delivered
	// without ever being stored in the map.
	if (this->fullStreamId.unordered && data.IsBeginning() && data.IsEnd())
	{
		AssembleMessage(tsn, std::move(data));

		return 0;
	}
	else if (!this->fullStreamId.unordered && mid == this->nextMid && data.IsBeginning() && data.IsEnd())
	{
		AssembleMessage(tsn, std::move(data));
		this->nextMid.Increment();

		// Advancing nextMid may make already queued messages deliverable.
		// Convert to a signed type before negating; negating the unsigned
		// size_t directly only works through wrap-around.
		return -static_cast<int32_t>(TryToAssembleMessages());
	}

	// Slow path: store the chunk under (mid, fsn).
	const auto [unused, inserted] =
	  this->chunksByMid[mid].emplace(fsn, std::make_pair(tsn, std::move(data)));

	if (!inserted)
	{
		// A chunk with this mid/fsn is already stored; nothing was added.
		return 0;
	}

	if (this->fullStreamId.unordered)
	{
		// Unordered messages are delivered as soon as they are complete.
		queuedBytes -= static_cast<int32_t>(TryToAssembleMessage(mid));
	}
	else
	{
		// Ordered messages can only be delivered starting from nextMid.
		if (mid == this->nextMid)
		{
			queuedBytes -= static_cast<int32_t>(TryToAssembleMessages());
		}
	}

	return queuedBytes;
}
/**
 * Drops every stored message whose message id is <= `mid`.
 *
 * For an ordered stream, `nextMid` is additionally moved past `mid` (if it
 * was not already beyond it) and any queued messages that became deliverable
 * are assembled.
 *
 * @param mid Wrapped message id up to which (inclusive) chunks are erased.
 * @return Payload bytes that left this stream: bytes of the erased chunks
 *         plus, for ordered streams, bytes of the messages assembled
 *         afterwards.
 */
size_t InterleavedReassemblyStreams::Stream::EraseTo(uint32_t mid)
{
	MS_TRACE();

	const Types::UnwrappedMid unwrappedMid = this->midUnwrapper.Unwrap(mid);
	size_t removedBytes                    = 0;

	// Relies on chunksByMid iterating in ascending message id order, so the
	// walk can stop at the first entry beyond `unwrappedMid`.
	auto it = this->chunksByMid.begin();

	while (it != this->chunksByMid.end() && it->first <= unwrappedMid)
	{
		// The seed must be a size_t: std::accumulate uses the seed's type for
		// its accumulator, and an `int` seed would narrow every partial sum.
		removedBytes += std::accumulate(
		  it->second.begin(),
		  it->second.end(),
		  size_t{ 0 },
		  [](size_t acc, const auto& i)
		  {
			  const auto& data = i.second.second;

			  return acc + data.GetPayloadLength();
		  });

		it = this->chunksByMid.erase(it);
	}

	if (!this->fullStreamId.unordered)
	{
		// Skipping messages on an ordered stream may unblock the ones queued
		// behind them.
		if (unwrappedMid >= this->nextMid)
		{
			this->nextMid = unwrappedMid.GetNextValue();
		}

		removedBytes += TryToAssembleMessages();
	}

	return removedBytes;
}
/**
 * Delivers the message identified by `mid` if all of its chunks are stored.
 *
 * The message is considered complete when the chunk with the lowest fragment
 * sequence number is marked as beginning, the one with the highest is marked
 * as end, and the number of stored chunks equals the span between those two
 * fragment sequence numbers.
 *
 * @param mid Unwrapped message id to look up.
 * @return Payload bytes of the delivered message, or 0 when nothing was
 *         delivered (no chunks stored for `mid`, or the message is still
 *         incomplete).
 */
size_t InterleavedReassemblyStreams::Stream::TryToAssembleMessage(Types::UnwrappedMid mid)
{
	MS_TRACE();

	const auto it = this->chunksByMid.find(mid);

	if (it == this->chunksByMid.end())
	{
		MS_DEBUG_DEV("cannot assemble message, no chunks [mid:%" PRIu32 "]", mid.Wrap());

		return 0;
	}

	// NOTE(review): the code below dereferences begin()/rbegin(), so it assumes
	// an entry in chunksByMid always holds at least one chunk. The only
	// insertion visible in this file (Stream::AddData) satisfies that.
	ChunkMap& chunks = it->second;

	if (!chunks.begin()->second.second.IsBeginning() || !chunks.rbegin()->second.second.IsEnd())
	{
		MS_DEBUG_DEV(
		  "cannot assemble message, missing beginning or end [mid:%" PRIu32 "]", mid.Wrap());

		return 0;
	}

	// Chunks are keyed by fragment sequence number, so with first and last in
	// place the message is gap-free exactly when (last - first) == size - 1.
	const int64_t fnsDiff = chunks.rbegin()->first - chunks.begin()->first;

	if (fnsDiff != (static_cast<int64_t>(chunks.size()) - 1))
	{
		MS_DEBUG_DEV(
		  "cannot assemble message, not all chunks exist [has:%zu, expected:%" PRIi64 "]",
		  chunks.size(),
		  fnsDiff + 1);

		return 0;
	}

	const size_t removedBytes = AssembleMessage(chunks);

	MS_DEBUG_DEV("message assembled [mid:%" PRIu32 ", removedBytes:%zu]", mid.Wrap(), removedBytes);

	// Erase by key rather than through `it`/`chunks`: AssembleMessage() has
	// just run the assembled-message callback, and nothing here guarantees
	// that it left the iterator valid.
	this->chunksByMid.erase(mid);

	return removedBytes;
}
/**
 * Delivers as many consecutive messages as possible, starting at `nextMid`.
 *
 * `nextMid` is advanced once for every message delivered. The run stops at
 * the first attempt that reports 0 bytes, which is treated as "nothing could
 * be assembled".
 *
 * @return Total payload bytes of the messages delivered by this call.
 */
size_t InterleavedReassemblyStreams::Stream::TryToAssembleMessages()
{
	MS_TRACE();

	size_t totalRemoved = 0;
	size_t removedNow   = TryToAssembleMessage(this->nextMid);

	while (removedNow != 0)
	{
		totalRemoved += removedNow;
		this->nextMid.Increment();

		removedNow = TryToAssembleMessage(this->nextMid);
	}

	return totalRemoved;
}
/**
 * Delivers a message that consists of a single chunk.
 *
 * The chunk's payload is released from `data` and handed to the message, so
 * it is not copied here.
 *
 * @param tsn  Unwrapped TSN of the chunk.
 * @param data The chunk holding the whole message.
 * @return Payload length of the delivered message, as read before the
 *         payload was released.
 */
size_t InterleavedReassemblyStreams::Stream::AssembleMessage(Types::UnwrappedTsn tsn, UserData data)
{
	MS_TRACE();

	// Read everything needed from `data` first: once its payload has been
	// released below, the object must not be queried any more.
	const auto protocolId       = data.GetPayloadProtocolId();
	const auto sid              = data.GetStreamId();
	const size_t deliveredBytes = data.GetPayloadLength();

	Message assembled(sid, protocolId, std::move(data).ReleasePayload());

	// The callback takes the TSNs as a range; wrap the single one in an array.
	const Types::UnwrappedTsn singleTsn[1] = { tsn };

	this->parent.onAssembledMessage(singleTsn, std::move(assembled));

	return deliveredBytes;
}
/**
 * Delivers the message made of all chunks in `tsnChunks`.
 *
 * A single chunk is forwarded to the single-chunk overload, which releases
 * the payload instead of copying it. Several chunks are concatenated, in the
 * iteration order of `tsnChunks`, into one freshly allocated payload. Stream
 * id and payload protocol id are taken from the first chunk.
 *
 * @param tsnChunks Chunks of one complete message; must not be empty. The
 *                  single-chunk path moves from the stored chunk.
 * @return Total payload length of the delivered message.
 */
size_t InterleavedReassemblyStreams::Stream::AssembleMessage(ChunkMap& tsnChunks)
{
	MS_TRACE();

	const size_t count = tsnChunks.size();

	if (count == 1)
	{
		// Fast path, no concatenation needed.
		return AssembleMessage(
		  tsnChunks.begin()->second.first, std::move(tsnChunks.begin()->second.second));
	}

	// The seed must be a size_t: std::accumulate uses the seed's type for its
	// accumulator, and an `int` seed would narrow every partial sum.
	const size_t payloadLength = std::accumulate(
	  tsnChunks.begin(),
	  tsnChunks.end(),
	  size_t{ 0 },
	  [](size_t acc, const auto& i)
	  {
		  const auto& data = i.second.second;

		  return acc + data.GetPayloadLength();
	  });

	std::vector<Types::UnwrappedTsn> tsns;
	std::vector<uint8_t> payload;

	// Sizes are known up front, so allocate each vector exactly once.
	tsns.reserve(count);
	payload.reserve(payloadLength);

	for (auto& item : tsnChunks)
	{
		const Types::UnwrappedTsn tsn = item.second.first;
		UserData& data                = item.second.second;

		// Fetch the payload once so begin() and end() are guaranteed to refer
		// to the same object, whatever GetPayload() returns.
		const auto& chunkPayload = data.GetPayload();

		tsns.push_back(tsn);
		payload.insert(payload.end(), chunkPayload.begin(), chunkPayload.end());
	}

	const UserData& data = tsnChunks.begin()->second.second;

	Message message(data.GetStreamId(), data.GetPayloadProtocolId(), std::move(payload));

	this->parent.onAssembledMessage(tsns, std::move(message));

	return payloadLength;
}
} }