#define MS_CLASS "PayloadChannel::PayloadChannelSocket"
#include "PayloadChannel/PayloadChannelSocket.hpp"
#include "Logger.hpp"
#include "MediaSoupErrors.hpp"
#include "PayloadChannel/PayloadChannelRequest.hpp"
#include <cmath>
#include <cstdio>
#include <cstring>
#include <memory>
extern "C"
{
#include <netstring.h>
}
namespace PayloadChannel
{
// Largest JSON message or binary payload carried by a single netstring (4 MiB).
static constexpr size_t NsPayloadMaxLen{ 4194304 };
// Largest complete netstring: the payload plus its "<length>:" header and trailing ",".
// 4194304 has 7 decimal digits, so the framing adds 7 + 1 + 1 = 9 bytes.
static constexpr size_t NsMessageMaxLen{ NsPayloadMaxLen + 9 };
PayloadChannelSocket::PayloadChannelSocket(int consumerFd, int producerFd)
  : consumerSocket(consumerFd, NsMessageMaxLen, this),
    producerSocket(producerFd, NsMessageMaxLen)
{
	MS_TRACE();

	// Scratch buffer in which SendImpl() assembles each outgoing netstring before writing it
	// to the producer socket; sized for the largest frame (NsMessageMaxLen bytes).
	// Released with std::free() in the destructor.
	this->writeBuffer = static_cast<uint8_t*>(std::malloc(NsMessageMaxLen));
}
PayloadChannelSocket::~PayloadChannelSocket()
{
	MS_TRACE();

	std::free(this->writeBuffer);

	// A notification or request whose JSON part was received but whose payload never
	// arrived is still owned by this socket; free both kinds (previously only the
	// notification was deleted, leaking a pending request).
	delete this->ongoingNotification;
	delete this->ongoingRequest;

	if (!this->closed)
		Close();
}
void PayloadChannelSocket::Close()
{
	MS_TRACE_STD();

	// Idempotent: only the first call marks the channel closed and closes both sockets.
	if (!this->closed)
	{
		this->closed = true;

		this->consumerSocket.Close();
		this->producerSocket.Close();
	}
}
void PayloadChannelSocket::SetListener(Listener* listener)
{
	MS_TRACE();

	// Non-owning pointer (never deleted by this class). It is dereferenced without a null
	// check when messages arrive or the consumer socket closes, so it must be set first.
	this->listener = listener;
}
void PayloadChannelSocket::Send(json& jsonMessage, const uint8_t* payload, size_t payloadLen)
{
	MS_TRACE();

	if (this->closed)
		return;

	const std::string serialized = jsonMessage.dump();

	// Validate both parts before writing anything, so a JSON message is never sent
	// without its accompanying payload.
	if (serialized.length() > NsPayloadMaxLen)
	{
		MS_ERROR("mesage too big");

		return;
	}

	if (payloadLen > NsPayloadMaxLen)
	{
		MS_ERROR("payload too big");

		return;
	}

	// Two consecutive netstrings: the JSON message, then the raw payload.
	SendImpl(serialized.c_str(), serialized.length());
	SendImpl(payload, payloadLen);
}
void PayloadChannelSocket::Send(json& jsonMessage)
{
	MS_TRACE_STD();

	// Silently ignored once the channel has been closed.
	if (!this->closed)
	{
		const std::string serialized = jsonMessage.dump();

		if (serialized.length() <= NsPayloadMaxLen)
		{
			SendImpl(serialized.c_str(), serialized.length());
		}
		else
		{
			MS_ERROR_STD("mesage too big");
		}
	}
}
inline void PayloadChannelSocket::SendImpl(const void* nsPayload, size_t nsPayloadLen)
{
MS_TRACE();
size_t nsNumLen;
if (nsPayloadLen == 0)
{
nsNumLen = 1;
this->writeBuffer[0] = '0';
this->writeBuffer[1] = ':';
this->writeBuffer[2] = ',';
}
else
{
nsNumLen = static_cast<size_t>(std::ceil(std::log10(static_cast<double>(nsPayloadLen) + 1)));
std::sprintf(reinterpret_cast<char*>(this->writeBuffer), "%zu:", nsPayloadLen);
std::memcpy(this->writeBuffer + nsNumLen + 1, nsPayload, nsPayloadLen);
this->writeBuffer[nsNumLen + nsPayloadLen + 1] = ',';
}
size_t nsLen = nsNumLen + nsPayloadLen + 2;
this->producerSocket.Write(this->writeBuffer, nsLen);
}
void PayloadChannelSocket::OnConsumerSocketMessage(
  ConsumerSocket* /*consumerSocket*/, char* msg, size_t msgLen)
{
	MS_TRACE();

	// Incoming data arrives in pairs: first a JSON request/notification, then its payload.
	// If a request or notification is pending, this message is its payload.
	if (this->ongoingNotification)
	{
		// Take ownership and clear the pending state before calling the listener. The object
		// is then freed, and the socket is ready for the next JSON message, even if the
		// listener throws something other than MediaSoupError.
		std::unique_ptr<PayloadChannel::Notification> notification(this->ongoingNotification);
		this->ongoingNotification = nullptr;

		notification->SetPayload(reinterpret_cast<const uint8_t*>(msg), msgLen);

		try
		{
			this->listener->OnPayloadChannelNotification(this, notification.get());
		}
		catch (const MediaSoupError& error)
		{
			MS_ERROR("notification failed: %s", error.what());
		}

		return;
	}

	if (this->ongoingRequest)
	{
		// Same ownership handling as for notifications above.
		std::unique_ptr<PayloadChannel::PayloadChannelRequest> request(this->ongoingRequest);
		this->ongoingRequest = nullptr;

		request->SetPayload(reinterpret_cast<const uint8_t*>(msg), msgLen);

		try
		{
			this->listener->OnPayloadChannelRequest(this, request.get());
		}
		catch (const MediaSoupTypeError& error)
		{
			request->TypeError(error.what());
		}
		catch (const MediaSoupError& error)
		{
			request->Error(error.what());
		}

		return;
	}

	// Nothing pending: this message must be JSON. Parse it exactly once, and never let a
	// malformed message escape as an exception into the consumer socket's read loop (this
	// object is that socket's listener).
	json jsonData;

	try
	{
		jsonData = json::parse(msg, msg + msgLen);
	}
	catch (const json::parse_error& error)
	{
		MS_ERROR_STD("JSON parsing error: %s", error.what());

		return;
	}

	if (PayloadChannelRequest::IsRequest(jsonData))
	{
		try
		{
			// Kept until its payload (the next message) arrives.
			this->ongoingRequest = new PayloadChannel::PayloadChannelRequest(this, jsonData);
		}
		catch (const MediaSoupError& /*error*/)
		{
			MS_ERROR("discarding wrong Payload Channel request");
		}
	}
	else if (Notification::IsNotification(jsonData))
	{
		try
		{
			// Kept until its payload (the next message) arrives.
			this->ongoingNotification = new PayloadChannel::Notification(jsonData);
		}
		catch (const MediaSoupError& /*error*/)
		{
			MS_ERROR("discarding wrong Payload Channel notification");
		}
	}
	else
	{
		MS_ERROR("discarding wrong Payload Channel data");
	}
}
void PayloadChannelSocket::OnConsumerSocketClosed(ConsumerSocket* /*consumerSocket*/)
{
	MS_TRACE();

	// Forward the consumer-side closure to whoever owns this channel.
	listener->OnPayloadChannelClosed(this);
}
ConsumerSocket::ConsumerSocket(int fd, size_t bufferSize, Listener* listener)
  : ::UnixStreamSocket(fd, bufferSize, ::UnixStreamSocket::Role::CONSUMER), listener(listener)
{
	MS_TRACE();

	// readBuffer receives a copy of each complete netstring payload taken from the receive
	// buffer, which presumably holds bufferSize bytes (the size passed to ::UnixStreamSocket).
	// A payload can therefore be nearly bufferSize bytes long. Size readBuffer from the same
	// argument rather than from NsMessageMaxLen, so a larger bufferSize cannot overflow it.
	this->readBuffer = static_cast<uint8_t*>(std::malloc(bufferSize));
}
ConsumerSocket::~ConsumerSocket()
{
	MS_TRACE();

	// Release the payload copy buffer allocated with std::malloc() in the constructor.
	std::free(readBuffer);
}
// Parses every complete netstring currently buffered in [msgStart, bufferDataLen) of the
// receive buffer and passes each payload to the listener.
// (this->buffer / bufferDataLen / bufferSize are presumably maintained by ::UnixStreamSocket.)
// A trailing incomplete netstring is kept for the next read. Malformed data causes the
// whole buffer to be discarded.
void ConsumerSocket::UserOnUnixStreamRead()
{
MS_TRACE();
while (true)
{
// Re-checked on each iteration because the socket may have been closed while the
// listener handled the previous message (presumably the callback can close it).
if (IsClosed())
return;
// Number of not-yet-parsed bytes after the last consumed message.
size_t readLen = this->bufferDataLen - this->msgStart;
char* msgStart = nullptr;
size_t msgLen;
// On success, msgStart points at the payload inside this->buffer and msgLen is its length.
int nsRet = netstring_read(
reinterpret_cast<char*>(this->buffer + this->msgStart), readLen, &msgStart, &msgLen);
if (nsRet != 0)
{
switch (nsRet)
{
// Only part of a netstring is buffered so far; wait for more data.
case NETSTRING_ERROR_TOO_SHORT:
{
// The buffer is full, so there is no room left for the rest of the message.
if (this->bufferDataLen == this->bufferSize)
{
// Compact: move the partial message to offset 0 to make room for the rest.
if (this->msgStart != 0)
{
std::memmove(this->buffer, this->buffer + this->msgStart, readLen);
this->msgStart = 0;
this->bufferDataLen = readLen;
}
// The partial message already starts at offset 0 and fills the whole buffer,
// so it can never be completed: drop it.
else
{
MS_ERROR(
"no more space in the buffer for the unfinished message being parsed, "
"discarding it");
this->msgStart = 0;
this->bufferDataLen = 0;
}
}
return;
}
case NETSTRING_ERROR_TOO_LONG:
{
MS_ERROR("NETSTRING_ERROR_TOO_LONG");
break;
}
case NETSTRING_ERROR_NO_COLON:
{
MS_ERROR("NETSTRING_ERROR_NO_COLON");
break;
}
case NETSTRING_ERROR_NO_COMMA:
{
MS_ERROR("NETSTRING_ERROR_NO_COMMA");
break;
}
case NETSTRING_ERROR_LEADING_ZERO:
{
MS_ERROR("NETSTRING_ERROR_LEADING_ZERO");
break;
}
case NETSTRING_ERROR_NO_LENGTH:
{
MS_ERROR("NETSTRING_ERROR_NO_LENGTH");
break;
}
}
// Any other parse error (or an unknown return code): discard everything buffered.
this->msgStart = 0;
this->bufferDataLen = 0;
return;
}
// Total bytes this netstring occupies: the "<length>:" header (the distance from the
// current message start to the payload), the payload itself, and the trailing ','.
readLen =
reinterpret_cast<const uint8_t*>(msgStart) - (this->buffer + this->msgStart) + msgLen + 1;
// Copy the payload into readBuffer (exactly msgLen bytes; no NUL terminator is added)
// and hand that copy to the listener.
std::memcpy(this->readBuffer, msgStart, msgLen);
this->listener->OnConsumerSocketMessage(this, reinterpret_cast<char*>(this->readBuffer), msgLen);
// This message ended exactly at the end of the buffer: restart at offset 0.
if ((this->msgStart + readLen) == this->bufferSize)
{
this->msgStart = 0;
this->bufferDataLen = 0;
}
// Otherwise advance past the consumed netstring.
else
{
this->msgStart += readLen;
}
// More buffered bytes remain: try to parse the next netstring.
if (this->bufferDataLen > this->msgStart)
{
continue;
}
break;
}
}
void ConsumerSocket::UserOnUnixStreamSocketClosed()
{
MS_TRACE();
this->listener->OnConsumerSocketClosed(this);
}
ProducerSocket::ProducerSocket(int fd, size_t bufferSize)
  : ::UnixStreamSocket(
      fd,
      bufferSize,
      ::UnixStreamSocket::Role::PRODUCER)
{
	// Write-side socket: all behavior comes from the ::UnixStreamSocket base in PRODUCER role.
	MS_TRACE();
}
}