#define MS_CLASS "Channel::ChannelSocket"
#include "Channel/ChannelSocket.hpp"
#include "DepLibUV.hpp"
#include "Logger.hpp"
#include "MediaSoupErrors.hpp"
#include <cstring>
namespace Channel
{
// Size (8 MiB) of the buffers handed to the consumer/producer Unix sockets.
// A framed message -- 4-byte length prefix followed by its payload -- has to
// fit in it.
static constexpr size_t MessageMaxLen{ 8u * 1024u * 1024u };
// Largest payload Send()/SendLog() accept: the buffer size minus the 4-byte
// (uint32_t) length prefix.
static constexpr size_t PayloadMaxLen{ MessageMaxLen - sizeof(uint32_t) };
// libuv async callback for the read handle. `handle->data` holds the owning
// ChannelSocket; keep calling CallbackRead() until it reports that no message
// was consumed.
inline static void onAsync(uv_handle_t* handle)
{
	auto* const channel = static_cast<ChannelSocket*>(handle->data);
	bool consumedMessage;

	do
	{
		consumedMessage = channel->CallbackRead();
	} while (consumedMessage);
}
// libuv close callback for the async read handle: frees the uv_async_t that
// the callback-based ChannelSocket constructor allocated with `new`.
inline static void onCloseAsync(uv_handle_t* handle)
{
	auto* const asyncHandle = reinterpret_cast<uv_async_t*>(handle);

	delete asyncHandle;
}
#if defined(MS_TEST) || defined(MS_FUZZER)
// Test/fuzzer-only constructor. It sets none of the socket or callback members
// itself; they keep whatever defaults the class declaration provides.
ChannelSocket::ChannelSocket()
{
	MS_TRACE_STD();
}
#endif
// Creates a channel backed by two Unix stream sockets.
//
// @param consumerFd  Descriptor read from; complete messages are delivered to
//                    OnConsumerSocketMessage() (this object is the consumer
//                    socket's listener).
// @param producerFd  Descriptor that SendImpl() writes to.
//
// Both sockets use MessageMaxLen-sized buffers.
//
// The sockets are created in the body rather than in the member-initializer
// list so that a failure while creating the producer socket can release the
// consumer socket. The destructor does not run for a partially constructed
// object, so without this the consumer socket would leak.
ChannelSocket::ChannelSocket(int consumerFd, int producerFd)
{
	MS_TRACE_STD();

	this->consumerSocket = new ConsumerSocket(consumerFd, MessageMaxLen, this);

	try
	{
		this->producerSocket = new ProducerSocket(producerFd, MessageMaxLen);
	}
	catch (...)
	{
		// Mirror the regular teardown (Close() then delete) for the socket that
		// was already created, then let the caller see the original error.
		this->consumerSocket->Close();
		delete this->consumerSocket;
		this->consumerSocket = nullptr;

		throw;
	}
}
// Creates a channel driven by user-supplied read/write callbacks instead of
// Unix sockets.
//
// A heap-allocated uv_async_t is registered on the DepLibUV loop with onAsync()
// as its callback, and is signalled once immediately so that onAsync() ->
// CallbackRead() runs on the loop.
//
// Throws (via MS_THROW_ERROR_STD) if the async handle cannot be initialized or
// signalled; in both cases the handle is released and uvReadHandle is left null.
ChannelSocket::ChannelSocket(
  ChannelReadFn channelReadFn,
  ChannelReadCtx channelReadCtx,
  ChannelWriteFn channelWriteFn,
  ChannelWriteCtx channelWriteCtx)
  : channelReadFn(channelReadFn),
    channelReadCtx(channelReadCtx),
    channelWriteFn(channelWriteFn),
    channelWriteCtx(channelWriteCtx),
    uvReadHandle(new uv_async_t)
{
	MS_TRACE_STD();

	this->uvReadHandle->data = static_cast<void*>(this);

	int err =
	  uv_async_init(DepLibUV::GetLoop(), this->uvReadHandle, reinterpret_cast<uv_async_cb>(onAsync));

	if (err != 0)
	{
		// Initialization failed, so the handle was never attached to the loop and
		// can be freed directly.
		delete this->uvReadHandle;
		this->uvReadHandle = nullptr;

		MS_THROW_ERROR_STD("uv_async_init() failed: %s", uv_strerror(err));
	}

	err = uv_async_send(this->uvReadHandle);

	if (err != 0)
	{
		// At this point the handle is initialized and owned by the loop. Freeing
		// it directly would leave a dangling handle there, so it must go through
		// uv_close(); onCloseAsync() deletes it once libuv is done with it.
		// `data` is cleared because this object will not survive the throw below.
		this->uvReadHandle->data = nullptr;
		uv_close(
		  reinterpret_cast<uv_handle_t*>(this->uvReadHandle), static_cast<uv_close_cb>(onCloseAsync));
		this->uvReadHandle = nullptr;

		MS_THROW_ERROR_STD("uv_async_send() failed: %s", uv_strerror(err));
	}
}
// Makes sure the channel is closed, then frees the consumer and producer socket
// wrappers. Deleting a null pointer is a no-op, so this is also safe when they
// were never created.
ChannelSocket::~ChannelSocket()
{
	MS_TRACE_STD();

	const bool needsClose = !this->closed;

	if (needsClose)
	{
		Close();
	}

	delete this->consumerSocket;
	delete this->producerSocket;
}
void ChannelSocket::Close()
{
MS_TRACE_STD();
if (this->closed)
{
return;
}
this->closed = true;
if (this->uvReadHandle)
{
uv_close(
reinterpret_cast<uv_handle_t*>(this->uvReadHandle), static_cast<uv_close_cb>(onCloseAsync));
}
if (this->consumerSocket)
{
this->consumerSocket->Close();
}
if (this->producerSocket)
{
this->producerSocket->Close();
}
}
// Sets the object that receives requests, notifications and the channel-closed
// event. The channel only stores the pointer; it never deletes the listener.
void ChannelSocket::SetListener(Listener* listener)
{
	MS_TRACE_STD();

	// The parameter shadows the member, hence the explicit `this->`.
	this->listener = listener;
}
// Sends `dataLen` bytes at `data` over the channel's transport.
// The data is silently ignored once the channel is closed. It is rejected, with
// an error logged, when it exceeds PayloadMaxLen.
void ChannelSocket::Send(const uint8_t* data, uint32_t dataLen)
{
	MS_TRACE_STD();

	if (this->closed)
	{
		return;
	}

	if (dataLen <= PayloadMaxLen)
	{
		SendImpl(data, dataLen);
	}
	else
	{
		MS_ERROR_STD("message too big");
	}
}
// Wraps `data` in a FlatBuffers Log message and sends it size-prefixed through
// Send(). Closed channels and oversized payloads (> PayloadMaxLen) are handled
// the same way Send() handles them.
//
// NOTE(review): CreateLogDirect() receives only the C string, so `dataLen` is
// used solely for the size check. `data` is presumably read up to its NUL
// terminator -- confirm callers always terminate it.
void ChannelSocket::SendLog(const char* data, uint32_t dataLen)
{
	MS_TRACE_STD();

	if (this->closed)
	{
		return;
	}

	if (dataLen > PayloadMaxLen)
	{
		MS_ERROR_STD("message too big");

		return;
	}

	const auto logOffset = FBS::Log::CreateLogDirect(this->bufferBuilder, data);
	const auto messageOffset =
	  FBS::Message::CreateMessage(this->bufferBuilder, FBS::Message::Body::Log, logOffset.Union());

	this->bufferBuilder.FinishSizePrefixed(messageOffset);

	const uint8_t* const serialized = this->bufferBuilder.GetBufferPointer();
	const auto serializedLen        = this->bufferBuilder.GetSize();

	this->Send(serialized, serializedLen);

	// The builder is a shared member; reset it for the next message.
	this->bufferBuilder.Clear();
}
// Pulls at most one message from the user-supplied read function and
// dispatches it to the listener.
//
// Requests get a ChannelRequest, which is answered with TypeError()/Error() if
// handling throws. Notifications get a ChannelNotification; failures are logged.
// Any other body type is logged and discarded.
//
// @return true if a message was read (and released through the returned free
//         function), false if the channel is closed or nothing was read. The
//         onAsync() drain loop keeps calling this while it returns true.
bool ChannelSocket::CallbackRead()
{
	MS_TRACE_STD();

	if (this->closed)
	{
		return false;
	}

	uint8_t* msg{ nullptr };
	uint32_t msgLen{ 0u };
	size_t msgCtx{ 0u };

	// The read function fills msg/msgLen/msgCtx and returns the function that
	// must later release the message; null means nothing was read.
	// (Named freeFn to avoid shadowing ::free.)
	auto freeFn = this->channelReadFn(
	  std::addressof(msg),
	  std::addressof(msgLen),
	  std::addressof(msgCtx),
	  this->uvReadHandle,
	  this->channelReadCtx);

	if (freeFn == nullptr)
	{
		return false;
	}

	const auto* message = FBS::Message::GetMessage(msg);

#if MS_LOG_DEV_LEVEL == 3
	auto s = flatbuffers::FlatBufferToString(
	  reinterpret_cast<uint8_t*>(msg), FBS::Message::MessageTypeTable());
	MS_DUMP("%s", s.c_str());
#endif

	if (message->data_type() == FBS::Message::Body::Request)
	{
		ChannelRequest* request{ nullptr };

		try
		{
			request = new ChannelRequest(this, message->data_as<FBS::Request::Request>());

			this->listener->HandleRequest(request);
		}
		// If the ChannelRequest constructor itself throws, `request` is still null
		// and there is no object to reply through, so just log the failure.
		catch (const MediaSoupTypeError& error)
		{
			if (request != nullptr)
			{
				request->TypeError(error.what());
			}
			else
			{
				MS_ERROR("request could not be created: %s", error.what());
			}
		}
		catch (const MediaSoupError& error)
		{
			if (request != nullptr)
			{
				request->Error(error.what());
			}
			else
			{
				MS_ERROR("request could not be created: %s", error.what());
			}
		}

		delete request;
	}
	else if (message->data_type() == FBS::Message::Body::Notification)
	{
		ChannelNotification* notification{ nullptr };

		try
		{
			notification = new ChannelNotification(message->data_as<FBS::Notification::Notification>());

			this->listener->HandleNotification(notification);
		}
		catch (const MediaSoupError& error)
		{
			MS_ERROR("notification failed: %s", error.what());
		}

		delete notification;
	}
	else
	{
		MS_ERROR("discarding wrong Channel data");
	}

	freeFn(msg, msgLen, msgCtx);

	return true;
}
// Writes an already-validated payload to the transport. The user write callback
// takes precedence; otherwise the producer Unix socket is used. If neither is
// available, the payload is dropped.
void ChannelSocket::SendImpl(const uint8_t* payload, uint32_t payloadLen)
{
	MS_TRACE_STD();

	if (this->channelWriteFn)
	{
		this->channelWriteFn(payload, payloadLen, this->channelWriteCtx);

		return;
	}

	if (this->producerSocket != nullptr)
	{
		this->producerSocket->Write(payload, payloadLen);
	}
}
// ConsumerSocket listener hook. It is called once per complete frame with `msg`
// pointing at the FlatBuffers payload (after the 4-byte length prefix); the
// length argument is unused here.
//
// Requests get a ChannelRequest, which is answered with TypeError()/Error() if
// handling throws. Notifications get a ChannelNotification; failures are logged.
// Any other body type is logged and discarded.
void ChannelSocket::OnConsumerSocketMessage(
  const ConsumerSocket* /*consumerSocket*/, char* msg, size_t /*msgLen*/)
{
	MS_TRACE();

	const auto* message = FBS::Message::GetMessage(msg);

#if MS_LOG_DEV_LEVEL == 3
	auto s = flatbuffers::FlatBufferToString(
	  reinterpret_cast<uint8_t*>(msg), FBS::Message::MessageTypeTable());
	MS_DUMP("%s", s.c_str());
#endif

	if (message->data_type() == FBS::Message::Body::Request)
	{
		ChannelRequest* request{ nullptr };

		try
		{
			request = new ChannelRequest(this, message->data_as<FBS::Request::Request>());

			this->listener->HandleRequest(request);
		}
		// If the ChannelRequest constructor itself throws, `request` is still null
		// and there is no object to reply through, so just log the failure.
		catch (const MediaSoupTypeError& error)
		{
			if (request != nullptr)
			{
				request->TypeError(error.what());
			}
			else
			{
				MS_ERROR("request could not be created: %s", error.what());
			}
		}
		catch (const MediaSoupError& error)
		{
			if (request != nullptr)
			{
				request->Error(error.what());
			}
			else
			{
				MS_ERROR("request could not be created: %s", error.what());
			}
		}

		delete request;
	}
	else if (message->data_type() == FBS::Message::Body::Notification)
	{
		ChannelNotification* notification{ nullptr };

		try
		{
			notification = new ChannelNotification(message->data_as<FBS::Notification::Notification>());

			this->listener->HandleNotification(notification);
		}
		catch (const MediaSoupError& error)
		{
			MS_ERROR("notification failed: %s", error.what());
		}

		delete notification;
	}
	else
	{
		MS_ERROR("discarding wrong Channel data");
	}
}
// ConsumerSocket listener hook, reached via
// ConsumerSocket::UserOnUnixStreamSocketClosed(). Reports the closure to the
// channel's listener.
void ChannelSocket::OnConsumerSocketClosed(const ConsumerSocket* /*consumerSocket*/)
{
	MS_TRACE_STD();

	auto* const channelListener = this->listener;

	channelListener->OnChannelClosed(this);
}
// Read side of a socket-backed channel. Wraps `fd` as a CONSUMER-role
// UnixStreamSocketHandle with a `bufferSize`-byte buffer; every complete frame
// it parses goes to `listener`.
ConsumerSocket::ConsumerSocket(int fd, size_t bufferSize, Listener* listener)
  : ::UnixStreamSocketHandle(fd, bufferSize, ::UnixStreamSocketHandle::Role::CONSUMER), listener(listener)
{
	MS_TRACE_STD();
}
// Nothing to release here beyond what the UnixStreamSocketHandle base class
// destructor handles; this only emits the trace.
ConsumerSocket::~ConsumerSocket()
{
	MS_TRACE_STD();
}
// Parses every complete frame currently in `buffer` and passes each one to the
// listener.
//
// Wire format: a uint32_t length (copied with memcpy, so in host byte order),
// followed by that many payload bytes. OnConsumerSocketMessage() receives a
// pointer to the payload (just past the prefix) and its length.
//
// Afterwards, any trailing bytes of an incomplete frame are moved to the start
// of `buffer` and `bufferDataLen` is reduced to match, so the next read appends
// after them.
//
// NOTE(review): a frame whose declared length exceeds what `buffer` can hold can
// never complete; this loop just keeps waiting for more data. Confirm that the
// base class handles a full buffer (or that peers never send such frames).
void ConsumerSocket::UserOnUnixStreamRead()
{
	MS_TRACE_STD();
	// Offset of the first byte not yet consumed by a dispatched frame.
	size_t msgStart{ 0 };
	// A single read may contain several frames.
	while (true)
	{
		// The listener may close the socket while handling a message. In that
		// case stop immediately, without compacting the buffer.
		if (IsClosed())
		{
			return;
		}
		const size_t readLen = this->bufferDataLen - msgStart;
		// Not even the length prefix has fully arrived yet.
		if (readLen < sizeof(uint32_t))
		{
			break;
		}
		uint32_t msgLen;
		// memcpy avoids an unaligned uint32_t read from the byte buffer.
		std::memcpy(std::addressof(msgLen), this->buffer + msgStart, sizeof(uint32_t));
		// Prefix present, but the payload is still incomplete.
		if (readLen < sizeof(uint32_t) + static_cast<size_t>(msgLen))
		{
			break;
		}
		this->listener->OnConsumerSocketMessage(
		  this,
		  reinterpret_cast<char*>(this->buffer + msgStart + sizeof(uint32_t)),
		  static_cast<size_t>(msgLen));
		msgStart += sizeof(uint32_t) + static_cast<size_t>(msgLen);
	}
	// Drop the consumed frames and move any partial frame to the front.
	if (msgStart != 0)
	{
		this->bufferDataLen = this->bufferDataLen - msgStart;
		if (this->bufferDataLen != 0)
		{
			// The regions may overlap, hence memmove rather than memcpy.
			std::memmove(this->buffer, this->buffer + msgStart, this->bufferDataLen);
		}
	}
}
// UnixStreamSocketHandle hook for socket closure; forwards it to the listener
// (the owning ChannelSocket, via OnConsumerSocketClosed()).
void ConsumerSocket::UserOnUnixStreamSocketClosed()
{
	MS_TRACE_STD();

	auto* const owner = this->listener;

	owner->OnConsumerSocketClosed(this);
}
// Write side of a socket-backed channel. Wraps `fd` as a PRODUCER-role
// UnixStreamSocketHandle with a `bufferSize`-byte buffer.
ProducerSocket::ProducerSocket(int fd, size_t bufferSize)
  : ::UnixStreamSocketHandle(fd, bufferSize, ::UnixStreamSocketHandle::Role::PRODUCER)
{
	MS_TRACE_STD();
}
}