#define MS_CLASS "UdpSocketHandle"
#include "handles/UdpSocketHandle.hpp"
#ifdef MS_LIBURING_SUPPORTED
#include "DepLibUring.hpp"
#endif
#include "Logger.hpp"
#include "MediaSoupErrors.hpp"
#include "Utils.hpp"
#include <cstring>
// Capacity offered to libuv for every incoming datagram: 64 KiB
// (presumably chosen so any non-jumbogram UDP datagram fits — confirm).
static constexpr size_t ReadBufferSize{ 64u * 1024u };
// Receive buffer shared by every UdpSocketHandle running on the same thread;
// each thread gets its own copy.
alignas(4) static thread_local uint8_t ReadBuffer[ReadBufferSize];
// libuv allocation callback. Forwards to the UdpSocketHandle stored in
// handle->data; does nothing once the socket has detached (data == nullptr).
inline static void onAlloc(uv_handle_t* handle, size_t suggestedSize, uv_buf_t* buf)
{
	auto* const owner = static_cast<UdpSocketHandle*>(handle->data);

	if (owner == nullptr)
	{
		return;
	}

	owner->OnUvRecvAlloc(suggestedSize, buf);
}
// libuv receive callback. Forwards the datagram (or error) to the owning
// UdpSocketHandle, unless the socket has already detached from the handle.
inline static void onRecv(
	uv_udp_t* handle, ssize_t nread, const uv_buf_t* buf, const struct sockaddr* addr, unsigned int flags)
{
	auto* const owner = static_cast<UdpSocketHandle*>(handle->data);

	if (owner == nullptr)
	{
		return;
	}

	owner->OnUvRecv(nread, buf, addr, flags);
}
// libuv completion callback for uv_udp_send(). Notifies the owning socket (if it
// is still attached to the handle) and then frees the UvSendData created in Send().
inline static void onSend(uv_udp_send_t* req, int status)
{
	auto* sendData = static_cast<UdpSocketHandle::UvSendData*>(req->data);
	auto* handle   = req->handle;
	auto* socket   = static_cast<UdpSocketHandle*>(handle->data);
	// OnUvSend() takes a non-const onSendCallback*, so cb must not be
	// const-qualified (a const pointer-to-callback would not convert).
	auto* cb = sendData->cb;

	// A null owner means the socket was closed; the callback is not invoked then.
	if (socket)
	{
		socket->OnUvSend(status, cb);
	}

	// NOTE(review): Send() never deletes cb separately on this path, so
	// ~UvSendData presumably releases both `store` and `cb` — confirm in header.
	delete sendData;
}
// libuv close callback: frees the uv_udp_t once libuv is done with it
// (assumes the handle was allocated with `new uv_udp_t`).
inline static void onCloseUdp(uv_handle_t* handle)
{
	auto* const udpHandle = reinterpret_cast<uv_udp_t*>(handle);

	delete udpHandle;
}
// Takes ownership of an already-bound uv_udp_t, starts receiving on it and
// caches the local address.
// On any failure the handle is closed (and later freed by onCloseUdp) before
// throwing, because the destructor does not run for a throwing constructor.
UdpSocketHandle::UdpSocketHandle(uv_udp_t* uvHandle) : uvHandle(uvHandle)
{
	MS_TRACE();

	this->uvHandle->data = static_cast<void*>(this);

	// Detach the handle from `this` before closing it, so the static libuv
	// callbacks (which check data for nullptr) never reach this object.
	auto closeHandle = [this]()
	{
		this->uvHandle->data = nullptr;
		uv_close(reinterpret_cast<uv_handle_t*>(this->uvHandle), static_cast<uv_close_cb>(onCloseUdp));
	};

	int err = uv_udp_recv_start(
		this->uvHandle, static_cast<uv_alloc_cb>(onAlloc), static_cast<uv_udp_recv_cb>(onRecv));

	if (err != 0)
	{
		closeHandle();

		MS_THROW_ERROR("uv_udp_recv_start() failed: %s", uv_strerror(err));
	}

	if (!SetLocalAddress())
	{
		closeHandle();

		MS_THROW_ERROR("error setting local IP and port");
	}

#ifdef MS_LIBURING_SUPPORTED
	if (DepLibUring::IsEnabled())
	{
		err = uv_fileno(reinterpret_cast<uv_handle_t*>(this->uvHandle), std::addressof(this->fd));

		if (err != 0)
		{
			// Receiving is already active here, so the handle must be closed too;
			// otherwise it leaks and a later datagram would call into this
			// never-constructed object.
			closeHandle();

			MS_THROW_ERROR("uv_fileno() failed: %s", uv_strerror(err));
		}
	}
#endif
}
// Closes the underlying handle unless that already happened. Any exception
// raised while closing is logged rather than propagated out of the destructor.
UdpSocketHandle::~UdpSocketHandle()
{
	MS_TRACE();

	if (this->closed)
	{
		return;
	}

	try
	{
		InternalClose();
	}
	catch (const std::exception& error)
	{
		MS_ERROR("error closing UDP socket: %s", error.what());
	}
}
// Writes a human-readable summary (local IP, local port, closed flag) at the
// given indentation level.
void UdpSocketHandle::Dump(int indentation) const
{
	const auto port            = static_cast<uint16_t>(this->localPort);
	const char* const closedTxt = this->closed ? "yes" : "no";

	MS_DUMP_CLEAN(indentation, "<UdpSocketHandle>");
	MS_DUMP_CLEAN(indentation, " local IP: %s", this->localIp.c_str());
	MS_DUMP_CLEAN(indentation, " local port: %" PRIu16, port);
	MS_DUMP_CLEAN(indentation, " closed: %s", closedTxt);
	MS_DUMP_CLEAN(indentation, "</UdpSocketHandle>");
}
void UdpSocketHandle::Send(
const uint8_t* data, size_t len, const struct sockaddr* addr, UdpSocketHandle::onSendCallback* cb)
{
MS_TRACE();
if (this->closed)
{
if (cb)
{
(*cb)(false);
delete cb;
}
return;
}
if (len == 0)
{
if (cb)
{
(*cb)(false);
delete cb;
}
return;
}
#ifdef MS_LIBURING_SUPPORTED
if (DepLibUring::IsEnabled())
{
if (!DepLibUring::IsActive())
{
goto send_libuv;
}
auto prepared = DepLibUring::PrepareSend(this->fd, data, len, addr, cb);
if (!prepared)
{
MS_DEBUG_DEV("cannot send via liburing, fallback to libuv");
goto send_libuv;
}
return;
}
send_libuv:
#endif
uv_buf_t buffer = uv_buf_init(reinterpret_cast<char*>(const_cast<uint8_t*>(data)), len);
const int sent = uv_udp_try_send(this->uvHandle, &buffer, 1, addr);
if (sent == static_cast<int>(len))
{
this->sentBytes += sent;
if (cb)
{
(*cb)(true);
delete cb;
}
return;
}
else if (sent >= 0)
{
MS_WARN_DEV("datagram truncated (just %d of %zu bytes were sent)", sent, len);
this->sentBytes += sent;
if (cb)
{
(*cb)(false);
delete cb;
}
return;
}
else if (sent != UV_EAGAIN)
{
MS_WARN_DEV("uv_udp_try_send() failed, trying uv_udp_send(): %s", uv_strerror(sent));
}
auto* sendData = new UvSendData(len);
sendData->req.data = static_cast<void*>(sendData);
std::memcpy(sendData->store, data, len);
sendData->cb = cb;
buffer = uv_buf_init(reinterpret_cast<char*>(sendData->store), len);
const int err = uv_udp_send(
&sendData->req, this->uvHandle, &buffer, 1, addr, static_cast<uv_udp_send_cb>(onSend));
if (err != 0)
{
MS_WARN_DEV("uv_udp_send() failed: %s", uv_strerror(err));
if (cb)
{
(*cb)(false);
}
delete sendData;
}
else
{
this->sentBytes += len;
}
}
uint32_t UdpSocketHandle::GetSendBufferSize() const
{
MS_TRACE();
int size{ 0 };
const int err =
uv_send_buffer_size(reinterpret_cast<uv_handle_t*>(this->uvHandle), std::addressof(size));
if (err)
{
MS_THROW_ERROR("uv_send_buffer_size() failed: %s", uv_strerror(err));
}
return static_cast<uint32_t>(size);
}
void UdpSocketHandle::SetSendBufferSize(uint32_t size)
{
MS_TRACE();
auto sizeInt = static_cast<int>(size);
if (sizeInt <= 0)
{
MS_THROW_TYPE_ERROR("invalid size: %d", sizeInt);
}
const int err =
uv_send_buffer_size(reinterpret_cast<uv_handle_t*>(this->uvHandle), std::addressof(sizeInt));
if (err)
{
MS_THROW_ERROR("uv_send_buffer_size() failed: %s", uv_strerror(err));
}
}
uint32_t UdpSocketHandle::GetRecvBufferSize() const
{
MS_TRACE();
int size{ 0 };
const int err =
uv_recv_buffer_size(reinterpret_cast<uv_handle_t*>(this->uvHandle), std::addressof(size));
if (err)
{
MS_THROW_ERROR("uv_recv_buffer_size() failed: %s", uv_strerror(err));
}
return static_cast<uint32_t>(size);
}
void UdpSocketHandle::SetRecvBufferSize(uint32_t size)
{
MS_TRACE();
auto sizeInt = static_cast<int>(size);
if (sizeInt <= 0)
{
MS_THROW_TYPE_ERROR("invalid size: %d", sizeInt);
}
const int err =
uv_recv_buffer_size(reinterpret_cast<uv_handle_t*>(this->uvHandle), std::addressof(sizeInt));
if (err)
{
MS_THROW_ERROR("uv_recv_buffer_size() failed: %s", uv_strerror(err));
}
}
void UdpSocketHandle::InternalClose()
{
MS_TRACE();
if (this->closed)
{
return;
}
this->closed = true;
this->uvHandle->data = nullptr;
const int err = uv_udp_recv_stop(this->uvHandle);
if (err != 0)
{
MS_ABORT("uv_udp_recv_stop() failed: %s", uv_strerror(err));
}
uv_close(reinterpret_cast<uv_handle_t*>(this->uvHandle), static_cast<uv_close_cb>(onCloseUdp));
}
// Reads the bound address of the socket into localAddr and derives localIp and
// localPort from it. Returns false (after logging) if libuv cannot report the
// address.
bool UdpSocketHandle::SetLocalAddress()
{
	MS_TRACE();

	auto* const boundAddr = reinterpret_cast<struct sockaddr*>(&this->localAddr);
	int addrLen           = sizeof(this->localAddr);

	const int err = uv_udp_getsockname(this->uvHandle, boundAddr, &addrLen);

	if (err != 0)
	{
		MS_ERROR("uv_udp_getsockname() failed: %s", uv_strerror(err));

		return false;
	}

	// The address family is an output of GetAddressInfo() that is not needed here.
	int family{ 0 };

	Utils::IP::GetAddressInfo(
		reinterpret_cast<const struct sockaddr*>(boundAddr), family, this->localIp, this->localPort);

	return true;
}
// Supplies libuv with a buffer for the next datagram. The suggested size is
// ignored: the shared per-thread ReadBuffer is always offered at full capacity.
inline void UdpSocketHandle::OnUvRecvAlloc(size_t /*suggestedSize*/, uv_buf_t* buf)
{
	MS_TRACE();

	buf->len  = ReadBufferSize;
	buf->base = reinterpret_cast<char*>(ReadBuffer);
}
// Handles the result of one libuv receive. Complete datagrams are counted and
// handed to UserOnUdpDatagramReceived(). Truncated datagrams are dropped and
// logged. Read errors are only logged in dev builds.
inline void UdpSocketHandle::OnUvRecv(
	ssize_t nread, const uv_buf_t* buf, const struct sockaddr* addr, unsigned int flags)
{
	MS_TRACE();

	// Zero bytes (nothing read, or an empty datagram) is ignored.
	if (nread == 0)
	{
		return;
	}

	// The datagram did not fit in the receive buffer.
	if ((flags & UV_UDP_PARTIAL) != 0u)
	{
		MS_ERROR("received datagram was truncated due to insufficient buffer, ignoring it");

		return;
	}

	if (nread < 0)
	{
		MS_DEBUG_DEV("read error: %s", uv_strerror(nread));

		return;
	}

	this->recvBytes += nread;

	auto* const payload = reinterpret_cast<uint8_t*>(buf->base);

	UserOnUdpDatagramReceived(payload, nread, ReadBufferSize, addr);
}
// Reports completion of a queued uv_udp_send(): invokes `cb` (if any) with
// true on success and false otherwise. Errors are logged only when
// MS_LOG_DEV_LEVEL is 3.
inline void UdpSocketHandle::OnUvSend(int status, UdpSocketHandle::onSendCallback* cb)
{
	MS_TRACE();

	const bool succeeded = (status == 0);

#if MS_LOG_DEV_LEVEL == 3
	if (!succeeded)
	{
		MS_DEBUG_DEV("send error: %s", uv_strerror(status));
	}
#endif

	if (cb)
	{
		(*cb)(succeeded);
	}
}