#include <gtest/gtest.h>
#include <chrono>
#ifdef _WIN32
#define INC_SRT_WIN_WINTIME
#else
typedef int SOCKET;
#define INVALID_SOCKET ((SOCKET)-1)
#define closesocket close
#endif
#include"platform_sys.h"
#include "srt.h"
using namespace std;
class TestConnectionTimeout
: public ::testing::Test
{
protected:
TestConnectionTimeout()
{
}
~TestConnectionTimeout()
{
}
protected:
void SetUp() override
{
ASSERT_EQ(srt_startup(), 0);
m_sa.sin_family = AF_INET;
m_sa.sin_addr.s_addr = INADDR_ANY;
m_udp_sock = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP);
ASSERT_NE(m_udp_sock, -1);
int bind_res = -1;
const sockaddr* psa = reinterpret_cast<const sockaddr*>(&m_sa);
for (int port = 5000; port <= 5555; ++port)
{
m_sa.sin_port = htons(port);
bind_res = ::bind(m_udp_sock, psa, sizeof m_sa);
if (bind_res >= 0)
{
cerr << "Running test on port " << port << "\n";
break;
}
}
ASSERT_GE(bind_res, 0);
ASSERT_EQ(inet_pton(AF_INET, "127.0.0.1", &m_sa.sin_addr), 1);
}
void TearDown() override
{
ASSERT_NE(closesocket(m_udp_sock), -1);
srt_cleanup();
}
protected:
SOCKET m_udp_sock = INVALID_SOCKET;
sockaddr_in m_sa = sockaddr_in();
};
// Non-blocking connect to the fixture's UDP port, where no SRT peer answers.
// The attempt is expected to fail after the configured SRTO_CONNTIMEO. The
// failure must surface through srt_epoll_wait() within +/-80 ms of that timeout.
TEST_F(TestConnectionTimeout, Nonblocking) {
    const SRTSOCKET client_sock = srt_create_socket();
    ASSERT_GT(client_sock, 0);

    // Verify the default connection timeout before overriding it.
    int conn_timeout = 0;
    int conn_timeout_len = sizeof conn_timeout;
    EXPECT_EQ(srt_getsockopt(client_sock, 0, SRTO_CONNTIMEO, &conn_timeout, &conn_timeout_len), SRT_SUCCESS);
    EXPECT_EQ(conn_timeout, 3000);

    const int connection_timeout_ms = 500;
    EXPECT_EQ(srt_setsockopt(client_sock, 0, SRTO_CONNTIMEO, &connection_timeout_ms, sizeof connection_timeout_ms), SRT_SUCCESS);

    // Switch the socket to non-blocking send/receive so srt_connect() returns
    // immediately and the outcome is observed via epoll.
    const int yes = 1;
    const int no = 0;
    ASSERT_EQ(srt_setsockopt(client_sock, 0, SRTO_RCVSYN, &no, sizeof no), SRT_SUCCESS);
    ASSERT_EQ(srt_setsockopt(client_sock, 0, SRTO_SNDSYN, &no, sizeof no), SRT_SUCCESS);
    ASSERT_EQ(srt_setsockopt(client_sock, 0, SRTO_TSBPDMODE, &yes, sizeof yes), SRT_SUCCESS);
    ASSERT_EQ(srt_setsockflag(client_sock, SRTO_SENDER, &yes, sizeof yes), SRT_SUCCESS);

    const int pollid = srt_epoll_create();
    ASSERT_GE(pollid, 0);
    const int epoll_out = SRT_EPOLL_OUT | SRT_EPOLL_ERR;
    ASSERT_NE(srt_epoll_add_usock(pollid, client_sock, &epoll_out), SRT_ERROR);

    const sockaddr* psa = reinterpret_cast<const sockaddr*>(&m_sa);
    ASSERT_NE(srt_connect(client_sock, psa, sizeof m_sa), SRT_ERROR);

    {
        // Zero-initialized so that, if epoll reports fewer sockets than
        // expected, the element checks below compare against 0 (never equal
        // to client_sock, which was asserted > 0) instead of reading
        // indeterminate values.
        int rlen = 2;
        SRTSOCKET ready_read[2] = {};
        int wlen = 2;
        SRTSOCKET ready_write[2] = {};

        const auto chrono_ts_start = std::chrono::steady_clock::now();
        EXPECT_EQ(srt_epoll_wait(pollid, ready_read, &rlen,
                                 ready_write, &wlen,
                                 connection_timeout_ms + 100, 0, 0, 0, 0),
                  2);
        const auto chrono_ts_end = std::chrono::steady_clock::now();
        const auto delta_ms = std::chrono::duration_cast<std::chrono::milliseconds>(
                                  chrono_ts_end - chrono_ts_start).count();

        // Allow 80 ms of timer/scheduling slack around the configured timeout.
        EXPECT_LE(delta_ms, connection_timeout_ms + 80);
        EXPECT_GE(delta_ms, connection_timeout_ms - 80);
        std::cerr << "Timeout was: " << delta_ms << "\n";

        // The failed socket is expected in both the read and the write set.
        EXPECT_EQ(rlen, 1);
        EXPECT_EQ(ready_read[0], client_sock);
        EXPECT_EQ(wlen, 1);
        EXPECT_EQ(ready_write[0], client_sock);
    }

    EXPECT_EQ(srt_epoll_remove_usock(pollid, client_sock), SRT_SUCCESS);
    EXPECT_EQ(srt_close(client_sock), SRT_SUCCESS);
    (void)srt_epoll_release(pollid);
}
// Repeated connect attempts on one socket toward the fixture's UDP port, where
// no SRT peer answers. SRTO_RCVSYN is left at its default here, as the test
// name suggests blocking mode. Each srt_connect() must fail with
// SRT_ENOSERVER. Repeating on the same socket checks that a failed attempt
// does not change the outcome of the next one. The loop stops early, with a
// diagnostic, at the first unexpected error.
TEST_F(TestConnectionTimeout, BlockingLoop)
{
    const SRTSOCKET caller = srt_create_socket();
    ASSERT_GT(caller, 0);

    const int timeout_ms = 999;
    EXPECT_EQ(srt_setsockopt(caller, 0, SRTO_CONNTIMEO, &timeout_ms, sizeof timeout_ms), SRT_SUCCESS);

    const sockaddr* target = reinterpret_cast<const sockaddr*>(&m_sa);
    const int max_attempts = 10;

    int attempt = 0;
    while (attempt < max_attempts)
    {
        EXPECT_EQ(srt_connect(caller, target, sizeof m_sa), SRT_ERROR);

        const int err = srt_getlasterror(nullptr);
        EXPECT_EQ(err, SRT_ENOSERVER);
        if (err != SRT_ENOSERVER)
        {
            std::cerr << "Connection attempt no. " << attempt << " resulted with: "
                      << err << " " << srt_getlasterror_str() << "\n";
            break;
        }
        ++attempt;
    }

    EXPECT_EQ(srt_close(caller), SRT_SUCCESS);
}