#define MS_CLASS "RTC::SCTP::TransmissionControlBlock"
#include "RTC/SCTP/association/TransmissionControlBlock.hpp"
#include "Logger.hpp"
#include "MediaSoupErrors.hpp"
#include "RTC/SCTP/packet/chunks/CookieAckChunk.hpp"
#include "RTC/SCTP/packet/chunks/CookieEchoChunk.hpp"
#include "RTC/SCTP/packet/chunks/DataChunk.hpp"
#include "RTC/SCTP/packet/chunks/IDataChunk.hpp"
#include <string>
namespace RTC
{
namespace SCTP
{
// Per-thread, 4-byte aligned scratch storage of 65536 bytes that
// CreatePacketWithVerificationTag() hands to Packet::Factory().
// NOTE(review): every packet created in this file is built over this same
// buffer, so presumably only one such packet may be in use at a time on a
// given thread — confirm against Packet::Factory().
alignas(4) static thread_local uint8_t PacketFactoryBuffer[65536];
/**
 * Builds the transmission control block.
 *
 * Stores the given references and values, creates the two backoff timers
 * through `shared` (both with `this` as listener, so their expirations arrive
 * in OnBackoffTimer()), constructs the helper members (RTO state, TX error
 * counter, data tracker, reassembly queue, retransmission queue, stream reset
 * handler and heartbeat handler) and finally tells `sendQueue` whether message
 * interleaving was negotiated.
 *
 * @param associationListenerDeferrer Stored by reference and also forwarded to
 *   the retransmission queue, stream reset handler and heartbeat handler.
 * @param sctpOptions Options used for the timers, RTO state, TX error counter,
 *   reassembly queue, retransmission queue and heartbeat handler.
 * @param shared Used to create the backoff timers; also forwarded to the
 *   stream reset and heartbeat handlers.
 * @param sendQueue Forwarded to the retransmission queue.
 * @param packetSender Stored by reference; used to send packets.
 * @param localVerificationTag Stored as is.
 * @param remoteVerificationTag Stored; CreatePacket() sets it on new packets.
 * @param localInitialTsn Stored and forwarded to the retransmission queue.
 * @param remoteInitialTsn Stored and forwarded to the data tracker.
 * @param remoteAdvertisedReceiverWindowCredit Stored and forwarded to the
 *   retransmission queue.
 * @param tieTag Stored as is.
 * @param negotiatedCapabilities Stored; its flags are read throughout the class.
 * @param maxPacketLength Stored; passed to Packet::Factory() when creating
 *   packets.
 * @param isAssociationEstablished Callback moved into the member of the same
 *   name.
 */
TransmissionControlBlock::TransmissionControlBlock(
AssociationListenerDeferrer& associationListenerDeferrer,
const SctpOptions& sctpOptions,
SharedInterface* shared,
SendQueueInterface& sendQueue,
PacketSender& packetSender,
uint32_t localVerificationTag,
uint32_t remoteVerificationTag,
uint32_t localInitialTsn,
uint32_t remoteInitialTsn,
uint32_t remoteAdvertisedReceiverWindowCredit,
uint64_t tieTag,
const NegotiatedCapabilities& negotiatedCapabilities,
size_t maxPacketLength,
std::function<bool()> isAssociationEstablished)
: associationListenerDeferrer(associationListenerDeferrer),
sctpOptions(sctpOptions),
shared(shared),
packetSender(packetSender),
localVerificationTag(localVerificationTag),
remoteVerificationTag(remoteVerificationTag),
localInitialTsn(localInitialTsn),
remoteInitialTsn(remoteInitialTsn),
remoteAdvertisedReceiverWindowCredit(remoteAdvertisedReceiverWindowCredit),
tieTag(tieTag),
negotiatedCapabilities(negotiatedCapabilities),
maxPacketLength(maxPacketLength),
isAssociationEstablished(std::move(isAssociationEstablished)),
// "sctp-t3-rtx" timer: starts at the initial RTO, exponential backoff capped
// by the configured maximum, no restart limit.
t3RtxTimer(this->shared->CreateBackoffTimer(
BackoffTimerHandleInterface::BackoffTimerHandleOptions{
.listener = this,
.label = "sctp-t3-rtx",
.baseTimeoutMs = sctpOptions.initialRtoMs,
.backoffAlgorithm = BackoffTimerHandleInterface::BackoffAlgorithm::EXPONENTIAL,
.maxBackoffTimeoutMs = sctpOptions.timerMaxBackoffTimeoutMs,
.maxRestarts = std::nullopt })),
// "sctp-delayed-ack" timer: starts at the configured maximum delayed ACK
// timeout, no backoff cap and zero restarts.
delayedAckTimer(this->shared->CreateBackoffTimer(
BackoffTimerHandleInterface::BackoffTimerHandleOptions{
.listener = this,
.label = "sctp-delayed-ack",
.baseTimeoutMs = sctpOptions.delayedAckMaxTimeoutMs,
.backoffAlgorithm = BackoffTimerHandleInterface::BackoffAlgorithm::EXPONENTIAL,
.maxBackoffTimeoutMs = std::nullopt,
.maxRestarts = 0 })),
rto(sctpOptions),
txErrorCounter(sctpOptions),
// The data tracker receives the raw pointer of the delayed ACK timer.
dataTracker(this->delayedAckTimer.get(), remoteInitialTsn),
reassemblyQueue(
sctpOptions.maxReceiverWindowBufferSize, negotiatedCapabilities.messageInterleaving),
// The retransmission queue receives the raw pointer of the T3-RTX timer.
retransmissionQueue(
this,
this->associationListenerDeferrer,
localInitialTsn,
remoteAdvertisedReceiverWindowCredit,
sendQueue,
this->t3RtxTimer.get(),
sctpOptions,
negotiatedCapabilities.partialReliability,
negotiatedCapabilities.messageInterleaving),
// The stream reset handler is given the addresses of the three members
// constructed just above.
streamResetHandler(
this->associationListenerDeferrer,
this->shared,
this,
std::addressof(this->dataTracker),
std::addressof(this->reassemblyQueue),
std::addressof(this->retransmissionQueue)),
heartbeatHandler(this->associationListenerDeferrer, sctpOptions, this->shared, this)
{
MS_TRACE();
// Propagate the negotiated message interleaving setting to the send queue.
sendQueue.EnableMessageInterleaving(this->negotiatedCapabilities.messageInterleaving);
}
/**
 * Destructor. Only emits the trace log; it contains no explicit cleanup code.
 */
TransmissionControlBlock::~TransmissionControlBlock()
{
MS_TRACE();
}
/**
 * Logs the state of this object between opening and closing
 * <SCTP::TransmissionControlBlock> markers.
 *
 * Prints the verification tags, initial TSNs, remote advertised receiver
 * window credit and tie-tag, then delegates to the Dump() methods of the
 * negotiated capabilities, the RTO state and the TX error counter.
 *
 * @param indentation Indentation level passed to MS_DUMP_CLEAN(); nested
 *   objects are dumped with `indentation + 1`.
 */
void TransmissionControlBlock::Dump(int indentation) const
{
MS_TRACE();
MS_DUMP_CLEAN(indentation, "<SCTP::TransmissionControlBlock>");
MS_DUMP_CLEAN(indentation, " local verification tag: %" PRIu32, this->localVerificationTag);
MS_DUMP_CLEAN(indentation, " remote verification tag: %" PRIu32, this->remoteVerificationTag);
MS_DUMP_CLEAN(indentation, " local initial tsn: %" PRIu32, this->localInitialTsn);
MS_DUMP_CLEAN(indentation, " remote initial tsn: %" PRIu32, this->remoteInitialTsn);
MS_DUMP_CLEAN(
indentation,
" remote advertised receiver window credit: %" PRIu32,
this->remoteAdvertisedReceiverWindowCredit);
MS_DUMP_CLEAN(indentation, " tie-tag: %" PRIu64, this->tieTag);
// Nested objects are dumped one indentation level deeper.
this->negotiatedCapabilities.Dump(indentation + 1);
this->rto.Dump(indentation + 1);
this->txErrorCounter.Dump(indentation + 1);
MS_DUMP_CLEAN(indentation, "</SCTP::TransmissionControlBlock>");
}
/**
 * Feeds a new RTT measurement into the RTO state and re-derives the base
 * timeouts of both backoff timers from the resulting RTO.
 *
 * @param rttMs Measured round-trip time, in milliseconds.
 */
void TransmissionControlBlock::ObserveRttMs(uint64_t rttMs)
{
	MS_TRACE();

#if MS_LOG_DEV_LEVEL == 3
	// Previous RTO, captured only for the dev log below.
	const auto prevRtoMs = this->rto.GetRtoMs();
#endif

	this->rto.ObserveRttMs(rttMs);

	MS_DEBUG_DEV(
	  "new rtt:%" PRIu64 ", previous rto:%" PRIu64 ", new rto:%" PRIu64 ", srtt:%" PRIu64,
	  rttMs,
	  prevRtoMs,
	  this->rto.GetRtoMs(),
	  this->rto.GetSrttMs());

	// The T3-RTX timer follows the updated RTO.
	this->t3RtxTimer->SetBaseTimeoutMs(this->rto.GetRtoMs());

	// The delayed ACK timer uses half of the RTO, capped by the configured
	// maximum delayed ACK timeout.
	const auto halfRtoMs = static_cast<uint64_t>(this->rto.GetRtoMs() * 0.5);
	const uint64_t cappedAckTimeoutMs =
	  std::min(halfRtoMs, this->sctpOptions.delayedAckMaxTimeoutMs);

	this->delayedAckTimer->SetBaseTimeoutMs(cappedAckTimeoutMs);
}
/**
 * Creates a packet whose verification tag is the remote verification tag.
 *
 * @return The packet produced by CreatePacketWithVerificationTag().
 */
std::unique_ptr<Packet> TransmissionControlBlock::CreatePacket() const
{
MS_TRACE();
return CreatePacketWithVerificationTag(this->remoteVerificationTag);
}
std::unique_ptr<Packet> TransmissionControlBlock::CreatePacketWithVerificationTag(
uint32_t verificationTag) const
{
MS_TRACE();
auto packet =
std::unique_ptr<Packet>{ Packet::Factory(PacketFactoryBuffer, this->maxPacketLength) };
packet->SetSourcePort(this->sctpOptions.sourcePort);
packet->SetDestinationPort(this->sctpOptions.destinationPort);
packet->SetVerificationTag(verificationTag);
return packet;
}
/**
 * Hands the packet to the packet sender.
 *
 * @param packet Packet to send.
 * @return Whatever PacketSender::SendPacket() returns.
 */
bool TransmissionControlBlock::SendPacket(Packet* packet)
{
MS_TRACE();
return this->packetSender.SendPacket(
packet,
// True unless zero checksum was negotiated.
// NOTE(review): presumably this flag asks the sender to write the packet
// checksum — confirm against PacketSender::SendPacket().
!this->negotiatedCapabilities.zeroChecksum);
}
/**
 * Stores the remote state cookie.
 *
 * While a cookie is stored, SendBufferedPackets() places a COOKIE-ECHO chunk
 * carrying it in the first packet and sends a single packet only, and
 * OnT3RtxTimer() does not retransmit.
 *
 * @param remoteStateCookie Cookie bytes; moved into the member.
 */
void TransmissionControlBlock::SetRemoteStateCookie(std::vector<uint8_t> remoteStateCookie)
{
MS_TRACE();
this->remoteStateCookie = std::move(remoteStateCookie);
}
/**
 * Discards the stored remote state cookie, if any, so that
 * `remoteStateCookie.has_value()` becomes false.
 */
void TransmissionControlBlock::ClearRemoteStateCookie()
{
MS_TRACE();
this->remoteStateCookie.reset();
}
void TransmissionControlBlock::MaySendSackChunk()
{
MS_TRACE();
if (!this->dataTracker.ShouldSendAck( false))
{
return;
}
const auto packet = CreatePacket();
this->dataTracker.AddSackSelectiveAck(packet.get(), this->reassemblyQueue.GetRemainingBytes());
SendPacket(packet.get());
}
/**
 * Adds a forward-TSN chunk to the packet when the rate limit has elapsed and
 * the retransmission queue reports that one should be sent.
 *
 * After adding the chunk, the next one is held back for the smaller of 200 ms
 * and the current SRTT.
 *
 * @param packet Packet to which the chunk is added.
 * @param nowMs Current time, in milliseconds.
 */
void TransmissionControlBlock::MayAddForwardTsnChunk(Packet* packet, uint64_t nowMs)
{
	MS_TRACE();

	// Still inside the hold-back window of the previous forward-TSN chunk.
	if (nowMs < this->limitForwardTsnUntilMs)
	{
		return;
	}

	if (!this->retransmissionQueue.ShouldSendForwardTsn(nowMs))
	{
		return;
	}

	// The chunk variant depends on whether message interleaving was negotiated.
	if (!this->negotiatedCapabilities.messageInterleaving)
	{
		this->retransmissionQueue.AddForwardTsn(packet);
	}
	else
	{
		this->retransmissionQueue.AddIForwardTsn(packet);
	}

	const uint64_t holdBackMs = std::min(uint64_t{ 200 }, this->rto.GetSrttMs());

	this->limitForwardTsnUntilMs = nowMs + holdBackMs;
}
void TransmissionControlBlock::MaySendFastRetransmit()
{
MS_TRACE();
if (!this->retransmissionQueue.HasDataToBeFastRetransmitted())
{
return;
}
const auto packet = CreatePacket();
auto result =
this->retransmissionQueue.GetChunksForFastRetransmit(packet->GetAvailableLength());
for (auto& [tsn, data] : result)
{
if (this->negotiatedCapabilities.messageInterleaving)
{
auto* iDataChunk = packet->BuildChunkInPlace<IDataChunk>();
iDataChunk->SetTsn(tsn);
iDataChunk->SetUserData(std::move(data));
iDataChunk->Consolidate();
}
else
{
auto* dataChunk = packet->BuildChunkInPlace<DataChunk>();
dataChunk->SetTsn(tsn);
dataChunk->SetUserData(std::move(data));
dataChunk->Consolidate();
}
}
SendPacket(packet.get());
}
/**
 * Builds and sends up to `sctpOptions.maxBurst` packets.
 *
 * Only the first packet of the burst may carry control chunks, added in this
 * order: an optional COOKIE-ACK, a COOKIE-ECHO (when a remote state cookie is
 * stored), a SACK (when the data tracker asks for one), a forward-TSN chunk
 * (see MayAddForwardTsnChunk()) and a stream reset request. Every packet is
 * then filled with the chunks the retransmission queue returns for the space
 * left in it.
 *
 * The burst stops early when the packet sender returns false, or after the
 * first packet while a remote state cookie is stored.
 *
 * @param nowMs Current time, in milliseconds. The same value is used for the
 *   forward-TSN rate limiting and for the retransmission queue.
 * @param addCookieAckChunk Whether a COOKIE-ACK chunk must be added to the
 *   first packet.
 *
 * @throws Via MS_THROW_ERROR() if a COOKIE-ECHO chunk must be added but the
 *   packet already contains chunks (COOKIE-ECHO must be the first chunk).
 */
void TransmissionControlBlock::SendBufferedPackets(uint64_t nowMs, bool addCookieAckChunk)
{
	MS_TRACE();

	for (size_t packetIdx{ 0 }; packetIdx < this->sctpOptions.maxBurst; ++packetIdx)
	{
		const auto packet = CreatePacket();

		// Control chunks go into the first packet of the burst only.
		if (packetIdx == 0)
		{
			if (addCookieAckChunk)
			{
				MS_DEBUG_DEV("adding COOKIE-ACK chunk to the packet");

				auto* cookieAckChunk = packet->BuildChunkInPlace<CookieAckChunk>();

				cookieAckChunk->Consolidate();
			}

			if (this->remoteStateCookie.has_value())
			{
				// COOKIE-ECHO must be the first chunk of the packet.
				if (packet->GetChunksCount() > 0)
				{
					MS_THROW_ERROR(
					  "packet must have no chunks [addCookieAckChunk:%s]",
					  addCookieAckChunk ? "true" : "false");
				}

				auto* cookieEchoChunk = packet->BuildChunkInPlace<CookieEchoChunk>();

				cookieEchoChunk->SetCookie(
				  this->remoteStateCookie->data(),
				  static_cast<uint16_t>(this->remoteStateCookie->size()));
				cookieEchoChunk->Consolidate();
			}

			// Bundle a pending acknowledgement with the outgoing packet.
			if (this->dataTracker.ShouldSendAck(true))
			{
				this->dataTracker.AddSackSelectiveAck(
				  packet.get(), this->reassemblyQueue.GetRemainingBytes());
			}

			// Use the caller supplied time here as well (no shadowing local), so
			// the whole burst is evaluated against a single timestamp.
			MayAddForwardTsnChunk(packet.get(), nowMs);

			if (this->streamResetHandler.ShouldSendStreamResetRequest())
			{
				this->streamResetHandler.AddStreamResetRequest(packet.get());
			}
		}

		auto chunksToSend =
		  this->retransmissionQueue.GetChunksToSend(nowMs, packet->GetAvailableLength());

		// Data is about to be sent, so restart the heartbeat timer.
		if (!chunksToSend.empty())
		{
			this->heartbeatHandler.RestartTimer();
		}

		// The I flag is set on every data chunk of this packet while the
		// congestion window is below the configured number of MTUs.
		const bool immediateAck =
		  GetCwnd() < (this->sctpOptions.immediateSackUnderCwndMtus * this->sctpOptions.mtu);

		for (auto& [tsn, data] : chunksToSend)
		{
			if (this->negotiatedCapabilities.messageInterleaving)
			{
				auto* iDataChunk = packet->BuildChunkInPlace<IDataChunk>();

				iDataChunk->SetTsn(tsn);
				iDataChunk->SetI(immediateAck);
				iDataChunk->SetUserData(std::move(data));
				iDataChunk->Consolidate();
			}
			else
			{
				auto* dataChunk = packet->BuildChunkInPlace<DataChunk>();

				dataChunk->SetTsn(tsn);
				dataChunk->SetI(immediateAck);
				dataChunk->SetUserData(std::move(data));
				dataChunk->Consolidate();
			}
		}

		// The second argument is true unless zero checksum was negotiated, and
		// always true while a remote state cookie is stored (i.e. when the packet
		// carries a COOKIE-ECHO chunk).
		if (!this->packetSender.SendPacket(
		      packet.get(),
		      !this->negotiatedCapabilities.zeroChecksum || this->remoteStateCookie.has_value()))
		{
			break;
		}

		// While the remote state cookie is stored only a single packet is sent.
		if (this->remoteStateCookie.has_value())
		{
			break;
		}
	}
}
/**
 * Handles the expiration of the T3-RTX timer.
 *
 * Nothing is retransmitted while a remote state cookie is stored. Otherwise
 * the TX error counter is incremented and, if IncrementTxErrorCounter()
 * returns true, the retransmission queue is told about the expiry and
 * buffered packets are sent.
 *
 * Both reference parameters are unused.
 */
void TransmissionControlBlock::OnT3RtxTimer(uint64_t& /*baseTimeoutMs*/, bool& /*stop*/)
{
	MS_TRACE();

	// Declared first so it covers everything done below.
	const AssociationListenerDeferrer::ScopedDeferrer deferrer(this->associationListenerDeferrer);

	if (this->remoteStateCookie.has_value())
	{
		MS_DEBUG_DEV("not retransmitting as T1-cookie is active");

		return;
	}

	if (!IncrementTxErrorCounter("t3-rtx expired"))
	{
		return;
	}

	this->retransmissionQueue.HandleT3RtxTimerExpiry();

	SendBufferedPackets(this->shared->GetTimeMs());
}
/**
 * Handles the expiration of the delayed ACK timer: notifies the data tracker
 * and then sends a SACK if MaySendSackChunk() decides one is due.
 *
 * Both reference parameters are unused.
 */
void TransmissionControlBlock::OnDelayedAckTimer(uint64_t& , bool& )
{
MS_TRACE();
// Declared first so it covers everything done below.
const AssociationListenerDeferrer::ScopedDeferrer deferrer(this->associationListenerDeferrer);
this->dataTracker.HandleDelayedAckTimerExpiry();
MaySendSackChunk();
}
/**
 * Backoff timer listener callback. Logs the expiration and dispatches it to
 * the handler of the matching timer (T3-RTX or delayed ACK). Expirations of
 * any other timer are only logged.
 *
 * @param backoffTimer Timer that expired.
 * @param baseTimeoutMs Forwarded untouched to the specific handler.
 * @param stop Forwarded untouched to the specific handler.
 */
void TransmissionControlBlock::OnBackoffTimer(
  BackoffTimerHandleInterface* backoffTimer, uint64_t& baseTimeoutMs, bool& stop)
{
	MS_TRACE();

	const auto maxRestarts = backoffTimer->GetMaxRestarts();

	// A timer without restart limit is reported as "Infinite".
	MS_DEBUG_TAG(
	  sctp,
	  "%s timer has expired [expirations:%zu/%s]",
	  backoffTimer->GetLabel().c_str(),
	  backoffTimer->GetExpirationCount(),
	  maxRestarts ? std::to_string(maxRestarts.value()).c_str() : "Infinite");

	if (backoffTimer == this->t3RtxTimer.get())
	{
		OnT3RtxTimer(baseTimeoutMs, stop);
	}
	else if (backoffTimer == this->delayedAckTimer.get())
	{
		OnDelayedAckTimer(baseTimeoutMs, stop);
	}
}
/**
 * Retransmission queue callback carrying a new RTT measurement. Forwards it
 * to ObserveRttMs().
 *
 * @param newRttMs New round-trip time, in milliseconds.
 */
void TransmissionControlBlock::OnRetransmissionQueueNewRttMs(uint64_t newRttMs)
{
MS_TRACE();
ObserveRttMs(newRttMs);
}
/**
 * Retransmission queue callback asking to clear the retransmission counter.
 * Clears the TX error counter.
 */
void TransmissionControlBlock::OnRetransmissionQueueClearRetransmissionCounter()
{
MS_TRACE();
this->txErrorCounter.Clear();
}
} }