#include "sctptransport.hpp"
#include "dtlstransport.hpp"
#include "internals.hpp"
#include "logcounter.hpp"
#include "utils.hpp"
#include <algorithm>
#include <chrono>
#include <cstdarg>
#include <cstdio>
#include <exception>
#include <iostream>
#include <limits>
#include <shared_mutex>
#include <thread>
#include <unordered_set>
#include <vector>
#define USE_PMTUD 0
using namespace std::chrono_literals;
using namespace std::chrono;
namespace rtc::impl {
using utils::to_uint16;
using utils::to_uint32;
// Counter incremented by processData() for every message carrying an unrecognized PPID.
// NOTE(review): LogCounter presumably reports the accumulated count at warning level — confirm.
namespace {
LogCounter COUNTER_UNKNOWN_PPID(plog::warning,
                                "Number of SCTP packets received with an unknown PPID");
} // namespace
// Registry of live SctpTransport instances. The static usrsctp callbacks (UpcallCallback,
// WriteCallback) only receive a raw pointer, so they look it up here before dereferencing it.
class SctpTransport::InstancesSet {
public:
	using shared_lock = std::shared_lock<std::shared_mutex>;

	// Registers an instance (takes the mutex exclusively).
	void insert(SctpTransport *instance) {
		std::unique_lock guard(mMutex);
		mSet.insert(instance);
	}

	// Unregisters an instance (takes the mutex exclusively, so it waits for any shared lock
	// previously returned by lock()).
	void erase(SctpTransport *instance) {
		std::unique_lock guard(mMutex);
		mSet.erase(instance);
	}

	// Returns a held shared lock if `instance` is registered, nullopt otherwise. While the
	// returned lock is alive, erase() for that instance cannot proceed.
	optional<shared_lock> lock(SctpTransport *instance) noexcept {
		shared_lock guard(mMutex);
		if (mSet.count(instance) == 0)
			return nullopt;
		return std::make_optional(std::move(guard));
	}

private:
	std::unordered_set<SctpTransport *> mSet;
	std::shared_mutex mMutex;
};

// Allocated once and never deleted in this file.
// NOTE(review): presumably intentional so the registry stays valid during static destruction.
SctpTransport::InstancesSet *SctpTransport::Instances = new InstancesSet;
void SctpTransport::Init() {
usrsctp_init(0, SctpTransport::WriteCallback, SctpTransport::DebugCallback);
usrsctp_sysctl_set_sctp_pr_enable(1); usrsctp_sysctl_set_sctp_ecn_enable(0); #ifndef SCTP_ACCEPT_ZERO_CHECKSUM
usrsctp_enable_crc32c_offload(); #endif
#ifdef SCTP_DEBUG
usrsctp_sysctl_set_sctp_debug_on(SCTP_DEBUG_ALL);
#endif
}
// Pushes the usrsctp sysctl defaults derived from `s`; every unset field falls back to the
// default written inline below.
void SctpTransport::SetSettings(const SctpSettings &s) {
	// Converts a chrono duration to its tick count as uint32 (the defaults below are in ms)
	const auto ticks = [](auto duration) { return to_uint32(duration.count()); };

	// Socket buffer space (1 MiB each by default)
	usrsctp_sysctl_set_sctp_recvspace(to_uint32(s.recvBufferSize.value_or(1024 * 1024)));
	usrsctp_sysctl_set_sctp_sendspace(to_uint32(s.sendBufferSize.value_or(1024 * 1024)));

	// Queueing and congestion control
	usrsctp_sysctl_set_sctp_max_chunks_on_queue(to_uint32(s.maxChunksOnQueue.value_or(10 * 1024)));
	usrsctp_sysctl_set_sctp_initial_cwnd(to_uint32(s.initialCongestionWindow.value_or(10)));
	usrsctp_sysctl_set_sctp_max_burst_default(to_uint32(s.maxBurst.value_or(10)));
	usrsctp_sysctl_set_sctp_default_cc_module(to_uint32(s.congestionControlModule.value_or(0)));

	// Delayed SACK
	usrsctp_sysctl_set_sctp_delayed_sack_time_default(ticks(s.delayedSackTime.value_or(20ms)));

	// Retransmission timeouts; the same maximum applies to data and INIT retransmissions
	usrsctp_sysctl_set_sctp_rto_min_default(ticks(s.minRetransmitTimeout.value_or(200ms)));
	const auto maxRto = ticks(s.maxRetransmitTimeout.value_or(10000ms));
	usrsctp_sysctl_set_sctp_rto_max_default(maxRto);
	usrsctp_sysctl_set_sctp_init_rto_max_default(maxRto);
	usrsctp_sysctl_set_sctp_rto_initial_default(ticks(s.initialRetransmitTimeout.value_or(1000ms)));

	// Retransmission attempts, shared by INIT, association and path limits
	const auto maxRtx = to_uint32(s.maxRetransmitAttempts.value_or(5));
	usrsctp_sysctl_set_sctp_init_rtx_max_default(maxRtx);
	usrsctp_sysctl_set_sctp_assoc_rtx_max_default(maxRtx);
	usrsctp_sysctl_set_sctp_path_rtx_max_default(maxRtx);

	// Heartbeats
	usrsctp_sysctl_set_sctp_heartbeat_interval_default(ticks(s.heartbeatInterval.value_or(10000ms)));
}
// Shuts usrsctp down, retrying every 100 ms for as long as usrsctp_finish() returns non-zero.
void SctpTransport::Cleanup() {
	for (;;) {
		if (usrsctp_finish() == 0)
			break;
		std::this_thread::sleep_for(100ms);
	}
}
// Creates and configures the usrsctp socket backing this transport. The association itself is
// started later, by start() -> connect().
//
// Configuration applied, in order:
//  - upcall and non-blocking mode;
//  - SO_LINGER with a zero timeout and stream reset support;
//  - SCTP_RECVRCVINFO and the association-change / sender-dry / stream-reset events;
//  - SCTP_NODELAY, heartbeats and path MTU;
//  - stream counts and fragment interleave level 0;
//  - zero checksum (when available);
//  - send/receive buffers at least as large as mMaxMessageSize.
// Finally the address is registered with usrsctp and the instance with Instances, which the
// static callbacks consult.
//
// Throws std::runtime_error if any step fails. Once the socket exists, it is closed before the
// exception propagates: the destructor does not run for a partially constructed object, so the
// socket would otherwise leak.
SctpTransport::SctpTransport(shared_ptr<Transport> lower, const Configuration &config, Ports ports,
                             message_callback recvCallback, amount_callback bufferedAmountCallback,
                             state_callback stateChangeCallback)
    : Transport(lower, std::move(stateChangeCallback)),
      mMaxMessageSize(config.maxMessageSize.value_or(DEFAULT_LOCAL_MAX_MESSAGE_SIZE)),
      mPorts(std::move(ports)), mSendQueue(0, message_size_func),
      mBufferedAmountCallback(std::move(bufferedAmountCallback)) {
	onRecv(std::move(recvCallback));

	PLOG_DEBUG << "Initializing SCTP transport";

	mSock = usrsctp_socket(AF_CONN, SOCK_STREAM, IPPROTO_SCTP, nullptr, nullptr, 0, nullptr);
	if (!mSock)
		throw std::runtime_error("Could not create SCTP socket, errno=" + std::to_string(errno));

	try {
		usrsctp_set_upcall(mSock, &SctpTransport::UpcallCallback, this);

		if (usrsctp_set_non_blocking(mSock, 1))
			throw std::runtime_error("Unable to set non-blocking mode, errno=" +
			                         std::to_string(errno));

		// Linger enabled with a zero timeout
		struct linger sol = {};
		sol.l_onoff = 1;
		sol.l_linger = 0;
		if (usrsctp_setsockopt(mSock, SOL_SOCKET, SO_LINGER, &sol, sizeof(sol)))
			throw std::runtime_error("Could not set socket option SO_LINGER, errno=" +
			                         std::to_string(errno));

		struct sctp_assoc_value av = {};
		av.assoc_id = SCTP_ALL_ASSOC;
		av.assoc_value = 1;
		if (usrsctp_setsockopt(mSock, IPPROTO_SCTP, SCTP_ENABLE_STREAM_RESET, &av, sizeof(av)))
			throw std::runtime_error("Could not set socket option SCTP_ENABLE_STREAM_RESET, errno=" +
			                         std::to_string(errno));

		// Needed so that doRecv() gets the stream id and PPID of each message
		int on = 1;
		if (usrsctp_setsockopt(mSock, IPPROTO_SCTP, SCTP_RECVRCVINFO, &on, sizeof(on)))
			throw std::runtime_error("Could set socket option SCTP_RECVRCVINFO, errno=" +
			                         std::to_string(errno));

		// Events handled by processNotification()
		struct sctp_event se = {};
		se.se_assoc_id = SCTP_ALL_ASSOC;
		se.se_on = 1;
		se.se_type = SCTP_ASSOC_CHANGE;
		if (usrsctp_setsockopt(mSock, IPPROTO_SCTP, SCTP_EVENT, &se, sizeof(se)))
			throw std::runtime_error("Could not subscribe to event SCTP_ASSOC_CHANGE, errno=" +
			                         std::to_string(errno));
		se.se_type = SCTP_SENDER_DRY_EVENT;
		if (usrsctp_setsockopt(mSock, IPPROTO_SCTP, SCTP_EVENT, &se, sizeof(se)))
			throw std::runtime_error("Could not subscribe to event SCTP_SENDER_DRY_EVENT, errno=" +
			                         std::to_string(errno));
		se.se_type = SCTP_STREAM_RESET_EVENT;
		if (usrsctp_setsockopt(mSock, IPPROTO_SCTP, SCTP_EVENT, &se, sizeof(se)))
			throw std::runtime_error("Could not subscribe to event SCTP_STREAM_RESET_EVENT, errno=" +
			                         std::to_string(errno));

		// Disable Nagle-style coalescing to minimize latency
		int nodelay = 1;
		if (usrsctp_setsockopt(mSock, IPPROTO_SCTP, SCTP_NODELAY, &nodelay, sizeof(nodelay)))
			throw std::runtime_error("Could not set socket option SCTP_NODELAY, errno=" +
			                         std::to_string(errno));

		struct sctp_paddrparams spp = {};
		spp.spp_flags = SPP_HB_ENABLE; // heartbeats on

#if USE_PMTUD
		if (!config.mtu.has_value()) {
#else
		if (false) {
#endif
			spp.spp_flags |= SPP_PMTUD_ENABLE;
			PLOG_VERBOSE << "Path MTU discovery enabled";
		} else {
			spp.spp_flags |= SPP_PMTUD_DISABLE;
			// NOTE(review): presumably 12 = SCTP common header, 48 = DTLS overhead,
			// 8 = UDP header, 40 = IPv6 header
			size_t pmtu = config.mtu.value_or(DEFAULT_MTU) - 12 - 48 - 8 - 40;
			spp.spp_pathmtu = to_uint32(pmtu);
			PLOG_VERBOSE << "Path MTU discovery disabled, SCTP MTU set to " << pmtu;
		}
		if (usrsctp_setsockopt(mSock, IPPROTO_SCTP, SCTP_PEER_ADDR_PARAMS, &spp, sizeof(spp)))
			throw std::runtime_error("Could not set socket option SCTP_PEER_ADDR_PARAMS, errno=" +
			                         std::to_string(errno));

		struct sctp_initmsg sinit = {};
		sinit.sinit_num_ostreams = MAX_SCTP_STREAMS_COUNT;
		sinit.sinit_max_instreams = MAX_SCTP_STREAMS_COUNT;
		if (usrsctp_setsockopt(mSock, IPPROTO_SCTP, SCTP_INITMSG, &sinit, sizeof(sinit)))
			throw std::runtime_error("Could not set socket option SCTP_INITMSG, errno=" +
			                         std::to_string(errno));

		// Interleave level 0: partial deliveries of different messages are not mixed
		int level = 0;
		if (usrsctp_setsockopt(mSock, IPPROTO_SCTP, SCTP_FRAGMENT_INTERLEAVE, &level, sizeof(level)))
			throw std::runtime_error("Could not disable SCTP fragmented interleave, errno=" +
			                         std::to_string(errno));

#ifdef SCTP_ACCEPT_ZERO_CHECKSUM
		int edmid = SCTP_EDMID_LOWER_LAYER_DTLS;
		if (usrsctp_setsockopt(mSock, IPPROTO_SCTP, SCTP_ACCEPT_ZERO_CHECKSUM, &edmid,
		                       sizeof(edmid)))
			throw std::runtime_error("Could set socket option SCTP_ACCEPT_ZERO_CHECKSUM, errno=" +
			                         std::to_string(errno));
#endif

		int rcvBuf = 0;
		socklen_t rcvBufLen = sizeof(rcvBuf);
		if (usrsctp_getsockopt(mSock, SOL_SOCKET, SO_RCVBUF, &rcvBuf, &rcvBufLen))
			throw std::runtime_error("Could not get SCTP recv buffer size, errno=" +
			                         std::to_string(errno));
		int sndBuf = 0;
		socklen_t sndBufLen = sizeof(sndBuf);
		if (usrsctp_getsockopt(mSock, SOL_SOCKET, SO_SNDBUF, &sndBuf, &sndBufLen))
			throw std::runtime_error("Could not get SCTP send buffer size, errno=" +
			                         std::to_string(errno));

		// Make both buffers large enough to hold the largest allowed message
		const int minBuf = int(std::min(mMaxMessageSize, size_t(std::numeric_limits<int>::max())));
		rcvBuf = std::max(rcvBuf, minBuf);
		sndBuf = std::max(sndBuf, minBuf);
		if (usrsctp_setsockopt(mSock, SOL_SOCKET, SO_RCVBUF, &rcvBuf, sizeof(rcvBuf)))
			throw std::runtime_error("Could not set SCTP recv buffer size, errno=" +
			                         std::to_string(errno));
		if (usrsctp_setsockopt(mSock, SOL_SOCKET, SO_SNDBUF, &sndBuf, sizeof(sndBuf)))
			throw std::runtime_error("Could not set SCTP send buffer size, errno=" +
			                         std::to_string(errno));
	} catch (...) {
		// This instance is not in Instances yet, so UpcallCallback/WriteCallback ignore any
		// callback usrsctp might issue for it while the socket is being closed.
		usrsctp_close(mSock);
		mSock = nullptr;
		throw;
	}

	usrsctp_register_address(this);
	Instances->insert(this);
}
// Tears the transport down. Order matters:
//  1. join the processor (NOTE(review): presumably waits for queued doRecv/doFlush/flush tasks),
//  2. release any thread blocked in incoming() waiting for mWrittenOnce,
//  3. unregister from the lower layer (the counterpart of registerIncoming() in start()),
//  4. close the usrsctp socket, deregister the address and remove this instance from the
//     registry consulted by UpcallCallback/WriteCallback.
SctpTransport::~SctpTransport() {
PLOG_DEBUG << "Destroying SCTP transport";
mProcessor.join();
// incoming() waits on mWrittenCondition until mWrittenOnce is set (or the state is Failed);
// wake it before unregistering so the lower-layer thread is not left blocked there.
mWrittenOnce = true;
mWrittenCondition.notify_all();
unregisterIncoming();
usrsctp_close(mSock);
usrsctp_deregister_address(this);
// Blocks until no callback holds a shared lock on this instance (see InstancesSet::lock)
Instances->erase(this);
}
// Replaces the callback invoked by triggerBufferedAmount() whenever a stream's buffered amount
// changes. The parameter is a sink and is moved into place.
void SctpTransport::onBufferedAmount(amount_callback callback) {
	this->mBufferedAmountCallback = std::move(callback);
}
void SctpTransport::start() {
registerIncoming();
connect();
}
void SctpTransport::stop() { close(); }
// Builds a zero-initialized AF_CONN address for `port` whose opaque address is this transport.
// NOTE(review): presumably this is the pointer usrsctp later hands back to WriteCallback.
struct sockaddr_conn SctpTransport::getSockAddrConn(uint16_t port) {
	struct sockaddr_conn addr{};
	addr.sconn_addr = this;
	addr.sconn_port = htons(port); // network byte order
	addr.sconn_family = AF_CONN;
#ifdef HAVE_SCONN_LEN
	addr.sconn_len = sizeof(addr);
#endif
	return addr;
}
// Moves to Connecting, binds the local port and starts a non-blocking association towards the
// remote port. Throws std::runtime_error if binding fails or if connecting fails with anything
// other than EINPROGRESS.
void SctpTransport::connect() {
	PLOG_DEBUG << "SCTP connecting (local port=" << mPorts.local
	           << ", remote port=" << mPorts.remote << ")";
	changeState(State::Connecting);

	auto localAddr = getSockAddrConn(mPorts.local);
	const int bindResult =
	    usrsctp_bind(mSock, reinterpret_cast<struct sockaddr *>(&localAddr), sizeof(localAddr));
	if (bindResult != 0)
		throw std::runtime_error("Could not bind usrsctp socket, errno=" + std::to_string(errno));

	auto remoteAddr = getSockAddrConn(mPorts.remote);
	const int connectResult =
	    usrsctp_connect(mSock, reinterpret_cast<struct sockaddr *>(&remoteAddr), sizeof(remoteAddr));
	// The socket is non-blocking, so EINPROGRESS just means the handshake is under way
	if (connectResult != 0 && errno != EINPROGRESS)
		throw std::runtime_error("Connection attempt failed, errno=" + std::to_string(errno));
}
// Sends `message`, or queues it if it cannot go out now. A null message only flushes the queue.
// Returns true if the message (or, for null, the whole queue) was sent immediately, false if
// not connected or if the message was queued. Throws std::invalid_argument if the message
// exceeds mMaxMessageSize.
bool SctpTransport::send(message_ptr message) {
	std::lock_guard lock(mSendMutex);
	if (state() != State::Connected)
		return false;

	if (!message)
		return trySendQueue();

	PLOG_VERBOSE << "Send size=" << message->size();
	if (message->size() > mMaxMessageSize)
		throw std::invalid_argument("Message is too large");

	// Send directly only after the queue has fully drained, so ordering is preserved
	const bool sentNow = trySendQueue() && trySendMessage(message);
	if (sentNow)
		return true;

	mSendQueue.push(message);
	const auto stream = to_uint16(message->stream);
	updateBufferedAmount(stream, ptrdiff_t(message_size_func(message)));
	return false;
}
// Tries to drain the send queue. Returns false when not connected or when sending threw (the
// exception is logged, not propagated); otherwise true, even if messages remain queued.
bool SctpTransport::flush() {
	bool flushed = false;
	try {
		std::lock_guard guard(mSendMutex);
		if (state() == State::Connected) {
			trySendQueue();
			flushed = true;
		}
	} catch (const std::exception &ex) {
		PLOG_WARNING << "SCTP flush: " << ex.what();
		flushed = false;
	}
	return flushed;
}
// Queues a Reset marker for `stream` behind any pending data and schedules an asynchronous
// flush; trySendMessage() turns the marker into sendReset() when it reaches the queue head.
void SctpTransport::closeStream(unsigned int stream) {
	std::lock_guard guard(mSendMutex);
	mSendQueue.push(make_message(0, Message::Reset, to_uint16(stream)));
	auto self = shared_from_this();
	mProcessor.enqueue(&SctpTransport::flush, std::move(self));
}
// Stops the send queue, then:
//  - Connected: schedules a flush; trySendQueue() shuts the sending side down once the stopped
//    queue has drained.
//  - Connecting: shuts the socket down right away, marks the transport Failed and wakes any
//    thread waiting in incoming().
//  - any other state: nothing further.
void SctpTransport::close() {
	mSendQueue.stop();

	if (state() == State::Connected) {
		mProcessor.enqueue(&SctpTransport::flush, shared_from_this());
		return;
	}

	if (state() != State::Connecting)
		return;

	PLOG_DEBUG << "SCTP early shutdown";
	if (usrsctp_shutdown(mSock, SHUT_RDWR) != 0) {
		if (errno != ENOTCONN) {
			PLOG_WARNING << "SCTP shutdown failed, errno=" << errno;
		} else {
			PLOG_VERBOSE << "SCTP already shut down";
		}
	}
	changeState(State::Failed);
	mWrittenCondition.notify_all();
}
// Highest usable stream id: one less than the negotiated stream count (or
// MAX_SCTP_STREAMS_COUNT before negotiation), clamped at 0.
unsigned int SctpTransport::maxStream() const {
	const unsigned int streamsCount = mNegotiatedStreamsCount.value_or(MAX_SCTP_STREAMS_COUNT);
	if (streamsCount == 0)
		return 0;
	return streamsCount - 1;
}
// Entry point for packets from the lower transport. A null message means the lower layer went
// away: the transport becomes Disconnected and a null message is passed up. Otherwise the packet
// is fed into usrsctp with usrsctp_conninput().
// Until this side has written at least once (mWrittenOnce, set in handleWrite), the call blocks;
// it is released when the first write happens or the transport is Failed.
// NOTE(review): presumably this ensures the local INIT is sent before a remote INIT is processed.
void SctpTransport::incoming(message_ptr message) {
// Test the flag before locking to skip the mutex on the common path
// (NOTE(review): presumably mWrittenOnce is atomic).
if (!mWrittenOnce) { std::unique_lock lock(mWriteMutex);
mWrittenCondition.wait(lock, [&]() { return mWrittenOnce || state() == State::Failed; });
}
if (state() == State::Failed)
return;
if (!message) {
PLOG_INFO << "SCTP disconnected";
changeState(State::Disconnected);
recv(nullptr);
return;
}
PLOG_VERBOSE << "Incoming size=" << message->size();
usrsctp_conninput(this, message->data(), message->size(), 0);
}
// Marks every outgoing packet with DSCP 10 (the AF11 code point, RFC 2597) before handing it to
// the base Transport.
bool SctpTransport::outgoing(message_ptr message) {
	message->dscp = 10;
	return Transport::outgoing(std::move(message));
}
// Processor task scheduled by enqueueRecv(). It drains the usrsctp socket while the transport is
// neither Disconnected nor Failed:
//  - chunks flagged MSG_NOTIFICATION accumulate in mPartialNotification and go to
//    processNotification() on MSG_EOR;
//  - data chunks accumulate in mPartialMessage (truncated to mMaxMessageSize) and go to
//    processData() on MSG_EOR, with the stream id and PPID from the receive info.
// Exceptions are logged, not propagated.
void SctpTransport::doRecv() {
std::lock_guard lock(mRecvMutex);
// Clear the pending marker up front so that an upcall arriving while we drain can schedule
// another run (enqueueRecv() only enqueues when the count is 0).
--mPendingRecvCount;
try {
while (state() != State::Disconnected && state() != State::Failed) {
const size_t bufferSize = 65536;
byte buffer[bufferSize];
socklen_t fromlen = 0;
struct sctp_rcvinfo info = {};
socklen_t infolen = sizeof(info);
unsigned int infotype = 0;
int flags = 0;
ssize_t len = usrsctp_recvv(mSock, buffer, bufferSize, nullptr, &fromlen, &info,
&infolen, &infotype, &flags);
if (len < 0) {
// Nothing more to read for now (or the connection was reset): stop draining
if (errno == EWOULDBLOCK || errno == EAGAIN || errno == ECONNRESET)
break;
else
throw std::runtime_error("SCTP recv failed, errno=" + std::to_string(errno));
} else if (len == 0) {
break;
}
PLOG_VERBOSE << "SCTP recv, len=" << len;
// Notifications and user data are reassembled in separate buffers
if (flags & MSG_NOTIFICATION) {
mPartialNotification.insert(mPartialNotification.end(), buffer, buffer + len);
if (flags & MSG_EOR) {
// Swap out the complete notification so the member buffer is empty for the next one
binary notification;
mPartialNotification.swap(notification);
auto n = reinterpret_cast<union sctp_notification *>(notification.data());
processNotification(n, notification.size());
}
} else {
mPartialMessage.insert(mPartialMessage.end(), buffer, buffer + len);
if (mPartialMessage.size() > mMaxMessageSize) {
PLOG_WARNING << "SCTP message is too large, truncating it";
mPartialMessage.resize(mMaxMessageSize);
}
if (flags & MSG_EOR) {
binary message;
mPartialMessage.swap(message);
// The stream id and PPID come from the SCTP_RECVRCVINFO ancillary data
if (infotype != SCTP_RECVV_RCVINFO)
throw std::runtime_error("Missing SCTP recv info");
processData(std::move(message), info.rcv_sid, PayloadId(ntohl(info.rcv_ppid)));
}
}
}
} catch (const std::exception &e) {
PLOG_WARNING << e.what();
}
}
// Processor task scheduled by enqueueFlush(): clears the pending marker and drains the send
// queue, logging (not propagating) any exception.
void SctpTransport::doFlush() {
	std::lock_guard guard(mSendMutex);
	--mPendingFlushCount;
	try {
		trySendQueue();
	} catch (const std::exception &ex) {
		PLOG_WARNING << ex.what();
	}
}
// Schedules doRecv() on the processor unless a run is already pending. Does nothing when the
// transport is no longer owned by any shared_ptr.
void SctpTransport::enqueueRecv() {
	if (mPendingRecvCount > 0)
		return;

	auto self = weak_from_this().lock();
	if (!self)
		return;

	++mPendingRecvCount;
	mProcessor.enqueue(&SctpTransport::doRecv, std::move(self));
}
// Schedules doFlush() on the processor unless a run is already pending. Does nothing when the
// transport is no longer owned by any shared_ptr.
void SctpTransport::enqueueFlush() {
	if (mPendingFlushCount > 0)
		return;

	auto self = weak_from_this().lock();
	if (!self)
		return;

	++mPendingFlushCount;
	mProcessor.enqueue(&SctpTransport::doFlush, std::move(self));
}
// Sends queued messages in order until the queue is empty or one cannot be sent now.
// Returns false if the head message could not be sent; it stays queued for a later retry.
// Returns true once the queue is drained. If the queue has been stopped (see close()) and is
// drained, the sending side is shut down exactly once, guarded by mSendShutdown.
// All callers in this file (send, flush, doFlush) hold mSendMutex.
bool SctpTransport::trySendQueue() {
while (auto next = mSendQueue.peek()) {
message_ptr message = std::move(*next);
if (!trySendMessage(message))
return false;
// Pop only after a successful send so a failed message remains at the head
mSendQueue.pop();
updateBufferedAmount(to_uint16(message->stream), -ptrdiff_t(message_size_func(message)));
}
if (!mSendQueue.running() && !std::exchange(mSendShutdown, true)) {
PLOG_DEBUG << "SCTP shutdown";
if (usrsctp_shutdown(mSock, SHUT_WR)) {
if (errno == ENOTCONN) {
PLOG_VERBOSE << "SCTP already shut down";
} else {
PLOG_WARNING << "SCTP shutdown failed, errno=" << errno;
changeState(State::Disconnected);
recv(nullptr);
}
}
}
return true;
}
// Attempts to hand a single message to usrsctp.
// Returns false if the transport is not Connected or usrsctp would block (EWOULDBLOCK/EAGAIN),
// so the caller can retry later. Returns true once the message was accepted, or handled
// without a send: Reset messages go to sendReset(), and unknown types are dropped.
// Throws std::runtime_error on any other send error. Callers in this file hold mSendMutex.
bool SctpTransport::trySendMessage(message_ptr message) {
if (state() != State::Connected)
return false;
// Empty String/Binary messages use dedicated *_EMPTY PPIDs (see processData())
uint32_t ppid;
switch (message->type) {
case Message::String:
ppid = !message->empty() ? PPID_STRING : PPID_STRING_EMPTY;
break;
case Message::Binary:
ppid = !message->empty() ? PPID_BINARY : PPID_BINARY_EMPTY;
break;
case Message::Control:
ppid = PPID_CONTROL;
break;
case Message::Reset:
sendReset(uint16_t(message->stream));
return true;
default:
return true;
}
PLOG_VERBOSE << "SCTP try send size=" << message->size();
const Reliability reliability = message->reliability ? *message->reliability : Reliability();
struct sctp_sendv_spa spa = {};
// Send info: stream id, PPID (network byte order), complete record
spa.sendv_flags |= SCTP_SEND_SNDINFO_VALID;
spa.sendv_sndinfo.snd_sid = uint16_t(message->stream);
spa.sendv_sndinfo.snd_ppid = htonl(ppid);
spa.sendv_sndinfo.snd_flags |= SCTP_EOR;
// Partial-reliability info: the newer fields (maxPacketLifeTime, then maxRetransmits) take
// precedence over the deprecated typeDeprecated/rexmit pair
spa.sendv_flags |= SCTP_SEND_PRINFO_VALID;
if (reliability.unordered)
spa.sendv_sndinfo.snd_flags |= SCTP_UNORDERED;
if (reliability.maxPacketLifeTime) {
spa.sendv_flags |= SCTP_SEND_PRINFO_VALID;
spa.sendv_prinfo.pr_policy = SCTP_PR_SCTP_TTL;
spa.sendv_prinfo.pr_value = to_uint32(reliability.maxPacketLifeTime->count());
} else if (reliability.maxRetransmits) {
spa.sendv_flags |= SCTP_SEND_PRINFO_VALID;
spa.sendv_prinfo.pr_policy = SCTP_PR_SCTP_RTX;
spa.sendv_prinfo.pr_value = to_uint32(*reliability.maxRetransmits);
}
else switch (reliability.typeDeprecated) {
case Reliability::Type::Rexmit:
spa.sendv_flags |= SCTP_SEND_PRINFO_VALID;
spa.sendv_prinfo.pr_policy = SCTP_PR_SCTP_RTX;
spa.sendv_prinfo.pr_value = to_uint32(std::get<int>(reliability.rexmit));
break;
case Reliability::Type::Timed:
spa.sendv_flags |= SCTP_SEND_PRINFO_VALID;
spa.sendv_prinfo.pr_policy = SCTP_PR_SCTP_TTL;
spa.sendv_prinfo.pr_value = to_uint32(std::get<milliseconds>(reliability.rexmit).count());
break;
default:
spa.sendv_prinfo.pr_policy = SCTP_PR_SCTP_NONE;
break;
}
ssize_t ret;
if (!message->empty()) {
ret = usrsctp_sendv(mSock, message->data(), message->size(), nullptr, 0, &spa, sizeof(spa),
SCTP_SENDV_SPA, 0);
} else {
// Empty messages go out as one zero byte; the *_EMPTY PPID tells the receiver to ignore it
const char zero = 0;
ret = usrsctp_sendv(mSock, &zero, 1, nullptr, 0, &spa, sizeof(spa), SCTP_SENDV_SPA, 0);
}
if (ret < 0) {
if (errno == EWOULDBLOCK || errno == EAGAIN) {
PLOG_VERBOSE << "SCTP sending not possible";
return false;
}
PLOG_ERROR << "SCTP sending failed, errno=" << errno;
throw std::runtime_error("Sending failed, errno=" + std::to_string(errno));
}
PLOG_VERBOSE << "SCTP sent size=" << message->size();
// Only user payload (String/Binary) is counted in the statistics
if (message->type == Message::Binary || message->type == Message::String)
mBytesSent += message->size();
return true;
}
// Adjusts the buffered byte count of `streamId` by `delta` (clamped at 0) and reports the new
// value through triggerBufferedAmount(). Entries that drop to 0 are removed. A zero delta is
// a no-op and does not invoke the callback.
void SctpTransport::updateBufferedAmount(uint16_t streamId, ptrdiff_t delta) {
	if (delta == 0)
		return;

	auto entry = mBufferedAmount.emplace(streamId, 0).first;
	const ptrdiff_t updated = ptrdiff_t(entry->second) + delta;
	const size_t amount = updated > 0 ? size_t(updated) : 0;
	if (amount != 0)
		entry->second = amount;
	else
		mBufferedAmount.erase(entry);

	triggerBufferedAmount(streamId, amount);
}
// Invokes the buffered-amount callback synchronously; a std::exception thrown by the callback
// is logged instead of propagating into the send path.
void SctpTransport::triggerBufferedAmount(uint16_t streamId, size_t amount) {
	try {
		mBufferedAmountCallback(streamId, amount);
	} catch (const std::exception &ex) {
		PLOG_WARNING << "SCTP buffered amount callback: " << ex.what();
	}
}
void SctpTransport::sendReset(uint16_t streamId) {
if (state() != State::Connected)
return;
PLOG_DEBUG << "SCTP resetting stream " << streamId;
using srs_t = struct sctp_reset_streams;
const size_t len = sizeof(srs_t) + sizeof(uint16_t);
byte buffer[len] = {};
srs_t &srs = *reinterpret_cast<srs_t *>(buffer);
srs.srs_flags = SCTP_STREAM_RESET_OUTGOING;
srs.srs_number_streams = 1;
srs.srs_stream_list[0] = streamId;
mWritten = false;
if (usrsctp_setsockopt(mSock, IPPROTO_SCTP, SCTP_RESET_STREAMS, &srs, len) == 0) {
std::unique_lock lock(mWriteMutex); mWrittenCondition.wait_for(lock, 1000ms,
[&]() { return mWritten || state() != State::Connected; });
} else if (errno == EINVAL) {
PLOG_DEBUG << "SCTP stream " << streamId << " already reset";
} else {
PLOG_WARNING << "SCTP reset stream " << streamId << " failed, errno=" << errno;
}
}
// Reacts to a usrsctp socket upcall: schedules a receive pass when the socket is readable and a
// flush when it is writable. Exceptions are logged, never propagated.
void SctpTransport::handleUpcall() noexcept {
	try {
		PLOG_VERBOSE << "Handle upcall";
		const int events = usrsctp_get_events(mSock);
		const bool readable = (events & SCTP_EVENT_READ) != 0;
		const bool writable = (events & SCTP_EVENT_WRITE) != 0;
		if (readable)
			enqueueRecv();
		if (writable)
			enqueueFlush();
	} catch (const std::exception &ex) {
		PLOG_ERROR << "SCTP upcall: " << ex.what();
	}
}
// Called from WriteCallback with a packet usrsctp wants to send. The packet is forwarded to the
// lower transport via outgoing(). On success, mWritten and mWrittenOnce are set and waiters on
// mWrittenCondition (incoming(), sendReset()) are woken.
// Returns 0 on success, and -1 if the packet was not accepted or an exception occurred.
// The tos and set_df arguments are ignored.
int SctpTransport::handleWrite(byte *data, size_t len, uint8_t ,
uint8_t ) noexcept {
try {
// Held across outgoing() so the flag updates below are ordered with the condition waits
std::unique_lock lock(mWriteMutex);
PLOG_VERBOSE << "Handle write, len=" << len;
if (!outgoing(make_message(data, data + len)))
return -1;
mWritten = true;
mWrittenOnce = true;
mWrittenCondition.notify_all();
} catch (const std::exception &e) {
PLOG_ERROR << "SCTP write: " << e.what();
return -1;
}
return 0; }
// Dispatches one complete SCTP user message according to its PPID.
//
// The PPID-based fragmentation PPIDs (PPID_STRING_PARTIAL / PPID_BINARY_PARTIAL) are
// accumulated in mPartialStringData / mPartialBinaryData and delivered when the matching
// final PPID arrives. NOTE(review): RFC 8831 section 6.6 deprecates these; they are only
// handled on reception. Reassembled data is capped at mMaxMessageSize bytes, with the excess
// dropped and a warning, the same policy doRecv() applies to raw SCTP messages.
// Unknown PPIDs are counted in COUNTER_UNKNOWN_PPID and ignored.
void SctpTransport::processData(binary &&data, uint16_t sid, PayloadId ppid) {
	PLOG_VERBOSE << "Process data, size=" << data.size();

	// Appends `chunk` to `buffer`, truncating to mMaxMessageSize only when the limit is exceeded.
	// Resizing unconditionally would zero-pad the first fragment up to the maximum size and then
	// cut off every later fragment.
	auto appendCapped = [this](binary &buffer, const binary &chunk) {
		buffer.insert(buffer.end(), chunk.begin(), chunk.end());
		if (buffer.size() > mMaxMessageSize) {
			PLOG_WARNING << "SCTP message is too large, truncating it";
			buffer.resize(mMaxMessageSize);
		}
	};

	switch (ppid) {
	case PPID_CONTROL:
		recv(make_message(std::move(data), Message::Control, sid));
		break;

	case PPID_STRING_PARTIAL:
		appendCapped(mPartialStringData, data);
		break;

	case PPID_STRING:
		if (mPartialStringData.empty()) {
			mBytesReceived += data.size();
			recv(make_message(std::move(data), Message::String, sid));
		} else {
			appendCapped(mPartialStringData, data);
			mBytesReceived += mPartialStringData.size();
			auto message = make_message(std::move(mPartialStringData), Message::String, sid);
			mPartialStringData.clear();
			recv(std::move(message));
		}
		break;

	case PPID_STRING_EMPTY:
		// The payload (a single placeholder byte) is ignored; any pending partial data is delivered
		recv(make_message(std::move(mPartialStringData), Message::String, sid));
		mPartialStringData.clear();
		break;

	case PPID_BINARY_PARTIAL:
		appendCapped(mPartialBinaryData, data);
		break;

	case PPID_BINARY:
		if (mPartialBinaryData.empty()) {
			mBytesReceived += data.size();
			recv(make_message(std::move(data), Message::Binary, sid));
		} else {
			appendCapped(mPartialBinaryData, data);
			mBytesReceived += mPartialBinaryData.size();
			auto message = make_message(std::move(mPartialBinaryData), Message::Binary, sid);
			mPartialBinaryData.clear();
			recv(std::move(message));
		}
		break;

	case PPID_BINARY_EMPTY:
		recv(make_message(std::move(mPartialBinaryData), Message::Binary, sid));
		mPartialBinaryData.clear();
		break;

	default:
		COUNTER_UNKNOWN_PPID++;
		PLOG_VERBOSE << "Unknown PPID: " << uint32_t(ppid);
		return;
	}
}
void SctpTransport::processNotification(const union sctp_notification *notify, size_t len) {
if (len != size_t(notify->sn_header.sn_length)) {
PLOG_WARNING << "Unexpected notification length, expected=" << notify->sn_header.sn_length
<< ", actual=" << len;
return;
}
auto type = notify->sn_header.sn_type;
PLOG_VERBOSE << "Processing notification, type=" << type;
switch (type) {
case SCTP_ASSOC_CHANGE: {
PLOG_VERBOSE << "SCTP association change event";
const struct sctp_assoc_change &sac = notify->sn_assoc_change;
if (sac.sac_state == SCTP_COMM_UP) {
PLOG_DEBUG << "SCTP negotiated streams: incoming=" << sac.sac_inbound_streams
<< ", outgoing=" << sac.sac_outbound_streams;
mNegotiatedStreamsCount.emplace(
std::min(sac.sac_inbound_streams, sac.sac_outbound_streams));
PLOG_INFO << "SCTP connected";
changeState(State::Connected);
} else {
if (state() == State::Connected) {
PLOG_INFO << "SCTP disconnected";
changeState(State::Disconnected);
recv(nullptr);
} else {
PLOG_ERROR << "SCTP connection failed";
changeState(State::Failed);
}
mWrittenCondition.notify_all();
}
break;
}
case SCTP_SENDER_DRY_EVENT: {
PLOG_VERBOSE << "SCTP sender dry event";
flush();
break;
}
case SCTP_STREAM_RESET_EVENT: {
const struct sctp_stream_reset_event &reset_event = notify->sn_strreset_event;
const int count = (reset_event.strreset_length - sizeof(reset_event)) / sizeof(uint16_t);
const uint16_t flags = reset_event.strreset_flags;
IF_PLOG(plog::verbose) {
std::ostringstream desc;
desc << "flags=";
if (flags & SCTP_STREAM_RESET_OUTGOING_SSN && flags & SCTP_STREAM_RESET_INCOMING_SSN)
desc << "outgoing|incoming";
else if (flags & SCTP_STREAM_RESET_OUTGOING_SSN)
desc << "outgoing";
else if (flags & SCTP_STREAM_RESET_INCOMING_SSN)
desc << "incoming";
else
desc << "0";
desc << ", streams=[";
for (int i = 0; i < count; ++i) {
uint16_t streamId = reset_event.strreset_stream_list[i];
desc << (i != 0 ? "," : "") << streamId;
}
desc << "]";
PLOG_VERBOSE << "SCTP reset event, " << desc.str();
}
if (flags & SCTP_STREAM_RESET_INCOMING_SSN) {
for (int i = 0; i < count; ++i) {
uint16_t streamId = reset_event.strreset_stream_list[i];
recv(make_message(0, Message::Reset, streamId));
}
}
break;
}
default:
break;
}
}
// Resets the sent/received payload byte counters to zero.
void SctpTransport::clearStats() {
	mBytesSent = 0;
	mBytesReceived = 0;
}
// Bytes of String/Binary payload sent since construction or the last clearStats().
size_t SctpTransport::bytesSent() {
	return mBytesSent;
}
// Bytes of String/Binary payload received since construction or the last clearStats().
size_t SctpTransport::bytesReceived() {
	return mBytesReceived;
}
// Round-trip time of the primary path as reported by SCTP_STATUS (spinfo_srtt), or nullopt
// when not connected or when the query fails.
optional<milliseconds> SctpTransport::rtt() {
	if (state() != State::Connected)
		return nullopt;

	struct sctp_status status = {};
	socklen_t len = sizeof(status);
	const bool failed = usrsctp_getsockopt(mSock, IPPROTO_SCTP, SCTP_STATUS, &status, &len) != 0;
	if (failed)
		return nullopt;

	return milliseconds(status.sstat_primary.spinfo_srtt);
}
// Static usrsctp upcall trampoline. `arg` is the transport registered in the constructor. The
// call is forwarded only while that transport is still in Instances, and the registry's shared
// lock is held for the duration of the call.
void SctpTransport::UpcallCallback(struct socket *, void *arg, int) {
	auto *transport = static_cast<SctpTransport *>(arg);
	auto guard = Instances->lock(transport);
	if (!guard)
		return;
	transport->handleUpcall();
}
int SctpTransport::WriteCallback(void *ptr, void *data, size_t len, uint8_t tos, uint8_t set_df) {
auto *transport = static_cast<SctpTransport *>(ptr);
#ifndef SCTP_ACCEPT_ZERO_CHECKSUM
if (len >= 12) {
uint32_t *checksum = reinterpret_cast<uint32_t *>(data) + 2;
*checksum = 0;
*checksum = usrsctp_crc32c(data, len);
}
#endif
if (auto locked = Instances->lock(transport))
return transport->handleWrite(static_cast<byte *>(data), len, tos, set_df);
else
return -1;
}
// Debug sink registered with usrsctp_init() in Init(). It formats the printf-style message into
// a 1 KiB buffer (longer output is truncated), removes one trailing newline if present, and
// logs the result at verbose level.
void SctpTransport::DebugCallback(const char *format, ...) {
	const size_t bufferSize = 1024;
	char buffer[bufferSize];
	va_list va;
	va_start(va, format);
	int len = std::vsnprintf(buffer, bufferSize, format, va);
	va_end(va);
	if (len <= 0)
		return;

	// vsnprintf returns the untruncated length; clamp it to the characters actually written
	len = std::min(len, int(bufferSize - 1));
	// Strip only a real newline. Unconditionally overwriting the last character dropped real
	// content from messages without one, including every truncated message.
	if (buffer[len - 1] == '\n')
		buffer[len - 1] = '\0';

	PLOG_VERBOSE << "usrsctp: " << buffer;
}
}