#ifdef _MSC_VER
#define _CRT_SECURE_NO_WARNINGS 1
#include <io.h>
#endif
#include "platform_sys.h"
#define REQUIRE_CXX11 1
#include <cctype>
#include <cerrno>
#include <iostream>
#include <fstream>
#include <sstream>
#include <string>
#include <map>
#include <set>
#include <list>
#include <vector>
#include <deque>
#include <memory>
#include <algorithm>
#include <iterator>
#include <stdexcept>
#include <cstring>
#include <csignal>
#include <chrono>
#include <thread>
#include <mutex>
#include <atomic>
#include <condition_variable>
#include "apputil.hpp"
#include "uriparser.hpp"
#include "socketoptions.hpp"
#include "logsupport.hpp"
#include "transmitbase.hpp"
#include "verbose.hpp"
#include <srt.h>
#include <udt.h>
#include <logging.h>
#include <api.h>
#include <utilities.h>
using namespace std;
// Log functional area used for this application's own log messages.
// main() enables it when no --logfa is given, or when "app" appears in --logfa.
const srt_logging::LogFA SRT_LOGFA_APP = 10;
namespace srt_logging
{
// Application logger tagged "TUNNELAPP" in the SRT_LOGFA_APP area; it is
// placed in srt_logging, presumably so the LOGP macros can find it — TODO confirm.
Logger applog(SRT_LOGFA_APP, srt_logger_config, "TUNNELAPP");
}
using srt_logging::applog;
// One endpoint of a tunnel: a socket of some transport type ("srt" or "tcp"
// in this file) identified by a URI. Derived classes implement the socket
// operations; this class keeps the shared state and implements buffered
// reading (Read) on top of ReadInternal().
class Medium
{
    // Shared instance counter; each medium takes the next value, shown in id().
    static int s_counter;
    int m_counter;
public:
    // Result of Read().
    enum ReadStatus
    {
        RD_DATA, RD_AGAIN, RD_EOF, RD_ERROR
    };

    // Role of the medium's socket, chosen by InitMode().
    enum Mode
    {
        LISTENER, CALLER
    };

protected:
    UriParser m_uri;
    size_t m_chunk = 0;            // maximum number of bytes read per Read() call
    // Socket options, taken from the URI query parameters (e.g. "srt://host:port?latency=200").
    map<string, string> m_options;
    Mode m_mode = CALLER;          // set for real by InitMode()
    bool m_listener = false;
    bool m_open = false;
    bool m_eof = false;
    bool m_broken = false;         // set by Close(); ReadInternal() stops retrying timeouts
    mutex access;                  // locked by the derived CloseInternal() implementations

    // Wraps the accepted socket `sock` in a new medium of the same type as
    // `self`, with a URI built from the peer address `sa`. Returns an owning
    // raw pointer; the callers wrap it in a unique_ptr immediately.
    template <class DerivedMedium, class SocketType>
    static Medium* CreateAcceptor(DerivedMedium* self, const sockaddr_any& sa, SocketType sock, size_t chunk)
    {
        string addr = sockaddr_any(sa.get(), sizeof sa).str();
        DerivedMedium* m = new DerivedMedium(UriParser(self->type() + string("://") + addr), chunk);
        m->m_socket = sock;
        return m;
    }

public:
    string uri() { return m_uri.uri(); }

    // Short identifier: type name followed by the instance number, e.g. "srt3".
    string id()
    {
        std::ostringstream os;
        os << type() << m_counter;
        return os.str();
    }

    // The URI's query parameters become the socket options. Without this,
    // m_options would only ever hold what the media set themselves, and the
    // option checks in SrtMedium::Init() could never trigger.
    Medium(UriParser u, size_t ch): m_counter(s_counter++), m_uri(u), m_chunk(ch), m_options(m_uri.parameters()) {}
    Medium(): m_counter(s_counter++) {}

    virtual const char* type() = 0;
    virtual bool IsOpen() = 0;
    virtual void CloseInternal() = 0;

    // Marks the medium closed and broken, then releases the socket through
    // CloseInternal(). Safe to call more than once.
    void Close()
    {
        m_open = false;
        m_broken = true;
        CloseInternal();
    }

    virtual bool End() = 0;
    virtual int ReadInternal(char* output, int size) = 0;
    virtual bool IsErrorAgain() = 0;

    ReadStatus Read(bytevector& output);
    virtual void Write(bytevector& portion) = 0;

    virtual void CreateListener() = 0;
    virtual void CreateCaller() = 0;
    virtual unique_ptr<Medium> Accept() = 0;
    virtual void Connect() = 0;

    // Factory: creates a medium for the URI's scheme ("srt" or "tcp") and
    // initializes it in the given mode. Throws TransmissionError otherwise.
    static std::unique_ptr<Medium> Create(const std::string& url, size_t chunk, Mode);

    virtual bool Broken() = 0;

    virtual size_t Still() { return 0; }

    class ReadEOF: public std::runtime_error
    {
    public:
        ReadEOF(const std::string& fn): std::runtime_error( "EOF while reading file: " + fn )
        {
        }
    };

    class TransmissionError: public std::runtime_error
    {
    public:
        TransmissionError(const std::string& fn): std::runtime_error( fn )
        {
        }
    };

    // Always throws TransmissionError with an "ERROR (internal): " prefix.
    static void Error(const string& text)
    {
        throw TransmissionError("ERROR (internal): " + text);
    }

    virtual ~Medium()
    {
    }

protected:
    // Records the mode, lets the derived class validate options (Init) and
    // creates the listening or calling socket.
    void InitMode(Mode m)
    {
        m_mode = m;
        Init();

        if (m_mode == LISTENER)
        {
            CreateListener();
            m_listener = true;
        }
        else
        {
            CreateCaller();
        }

        m_open = true;
    }

    virtual void Init() {}
};
class Engine
{
Medium* media[2];
std::thread thr;
class Tunnel* parent_tunnel;
std::string nameid;
int status = 0;
Medium::ReadStatus rdst = Medium::RD_ERROR;
UDT::ERRORINFO srtx;
public:
enum Dir { DIR_IN, DIR_OUT };
int stat() { return status; }
Engine(Tunnel* p, Medium* m1, Medium* m2, const std::string& nid)
:
#ifdef HAVE_FULL_CXX11
media {m1, m2},
#endif
parent_tunnel(p), nameid(nid)
{
#ifndef HAVE_FULL_CXX11
media[0] = m1;
media[1] = m2;
#endif
}
void Start()
{
Verb() << "START: " << media[DIR_IN]->uri() << " --> " << media[DIR_OUT]->uri();
std::string thrn = media[DIR_IN]->id() + ">" + media[DIR_OUT]->id();
ThreadName tn(thrn.c_str());
thr = thread([this]() { Worker(); });
}
void Stop()
{
if (thr.joinable())
{
LOGP(applog.Debug, "Engine::Stop: Closing media:");
media[0]->Close();
media[1]->Close();
LOGP(applog.Debug, "Engine::Stop: media closed, joining engine thread:");
if (thr.get_id() == std::this_thread::get_id())
{
thr.detach();
LOGP(applog.Debug, "DETACHED.");
}
else
{
thr.join();
LOGP(applog.Debug, "Joined.");
}
}
}
void Worker();
};
struct Tunnelbox;

// A bidirectional relay between an accepted medium (med_acp) and the medium
// that called the relay target (med_clr), driven by two engines, one per direction.
class Tunnel
{
    Tunnelbox* parent_box;
    std::unique_ptr<Medium> med_acp, med_clr;
    Engine acp_to_clr, clr_to_acp;
    // Read without the mutex by engine threads in Stop(), so it must be an
    // atomic: volatile provides no inter-thread guarantees.
    std::atomic<bool> running {true};
    mutex access;

public:
    string show()
    {
        return med_acp->uri() + " <-> " + med_clr->uri();
    }

    Tunnel(Tunnelbox* m, std::unique_ptr<Medium>&& acp, std::unique_ptr<Medium>&& clr):
        parent_box(m),
        med_acp(move(acp)), med_clr(move(clr)),
        acp_to_clr(this, med_acp.get(), med_clr.get(), med_acp->id() + ">" + med_clr->id()),
        clr_to_acp(this, med_clr.get(), med_acp.get(), med_clr->id() + ">" + med_acp->id())
    {
    }

    void Start()
    {
        acp_to_clr.Start();
        clr_to_acp.Start();
    }

    // Called from an engine thread when its transfer loop ends. Closes both
    // media (Close() sets m_broken, so the other engine's read stops retrying
    // timeouts and fails too) and reports the tunnel as stopped.
    void decommission_engine(Medium* which_medium)
    {
        Verb() << "Medium broken: " << which_medium->uri();
        med_acp->Close();
        med_clr->Close();
        Stop();
    }

    void Stop();
    bool decommission_if_dead(bool forced);
};
void Engine::Worker()
{
bytevector outbuf;
Medium* which_medium = media[DIR_IN];
for (;;)
{
try
{
which_medium = media[DIR_IN];
rdst = media[DIR_IN]->Read((outbuf));
switch (rdst)
{
case Medium::RD_DATA:
{
which_medium = media[DIR_OUT];
media[DIR_OUT]->Write((outbuf));
}
break;
case Medium::RD_EOF:
status = -1;
throw Medium::ReadEOF("");
case Medium::RD_AGAIN:
case Medium::RD_ERROR:
status = -1;
Medium::Error("Error while reading");
}
}
catch (Medium::ReadEOF&)
{
Verb() << "EOF. Exiting engine.";
break;
}
catch (Medium::TransmissionError& er)
{
Verb() << er.what() << " - interrupting engine: " << nameid;
break;
}
}
parent_tunnel->decommission_engine(which_medium);
}
class SrtMedium: public Medium
{
SRTSOCKET m_socket = SRT_ERROR;
friend class Medium;
public:
#ifdef HAVE_FULL_CXX11
using Medium::Medium;
#else
SrtMedium(UriParser u, size_t ch): Medium(u, ch) {}
#endif
bool IsOpen() override { return m_open; }
bool End() override { return m_eof; }
bool Broken() override { return m_broken; }
void CloseInternal() override
{
Verb() << "Closing SRT socket for " << uri();
lock_guard<mutex> lk(access);
if (m_socket == SRT_ERROR)
return;
srt_close(m_socket);
m_socket = SRT_ERROR;
}
virtual const char* type() override { return "srt"; }
virtual int ReadInternal(char* output, int size) override;
virtual bool IsErrorAgain() override;
virtual void Write(bytevector& portion) override;
virtual void CreateListener() override;
virtual void CreateCaller() override;
virtual unique_ptr<Medium> Accept() override;
virtual void Connect() override;
protected:
virtual void Init() override;
void ConfigurePre();
void ConfigurePost(SRTSOCKET socket);
using Medium::Error;
static void Error(UDT::ERRORINFO& ri, const string& text)
{
throw TransmissionError("ERROR: " + text + ": " + ri.getErrorMessage());
}
virtual ~SrtMedium() override
{
Close();
}
};
// Medium backed by a plain TCP socket.
class TcpMedium: public Medium
{
    int m_socket = -1;
    friend class Medium;
public:
#ifdef HAVE_FULL_CXX11
    using Medium::Medium;
#else
    TcpMedium(UriParser u, size_t ch): Medium(u, ch) {}
#endif

    // Platform-specific socket close and the flags passed to ::send().
    // DEF_SEND_FLAG is MSG_NOSIGNAL where available, so that sending to a
    // peer that closed the connection fails with an error instead of raising
    // SIGPIPE. Apple platforms use SO_NOSIGPIPE instead (see ConfigurePre()).
#ifdef _WIN32
    static int tcp_close(int socket)
    {
        return ::closesocket(socket);
    }
    enum { DEF_SEND_FLAG = 0 };
#elif defined(LINUX) || defined(GNU) || defined(CYGWIN)
    static int tcp_close(int socket)
    {
        return ::close(socket);
    }
    enum { DEF_SEND_FLAG = MSG_NOSIGNAL };
#else
    static int tcp_close(int socket)
    {
        return ::close(socket);
    }
#if defined(MSG_NOSIGNAL) && !defined(__APPLE__)
    // Other POSIX systems (e.g. the BSDs) provide the flag too; without it a
    // write to a closed peer would kill the whole process with SIGPIPE.
    enum { DEF_SEND_FLAG = MSG_NOSIGNAL };
#else
    enum { DEF_SEND_FLAG = 0 };
#endif
#endif

    bool IsOpen() override { return m_open; }
    bool End() override { return m_eof; }
    bool Broken() override { return m_broken; }

    // Closes the TCP socket on the first call; later calls only log.
    void CloseInternal() override
    {
        Verb() << "Closing TCP socket for " << uri();
        lock_guard<mutex> lk(access);
        if (m_socket == -1)
            return;
        tcp_close(m_socket);
        m_socket = -1;
    }

    virtual const char* type() override { return "tcp"; }
    virtual int ReadInternal(char* output, int size) override;
    virtual bool IsErrorAgain() override;

    virtual void Write(bytevector& portion) override;
    virtual void CreateListener() override;
    virtual void CreateCaller() override;
    virtual unique_ptr<Medium> Accept() override;
    virtual void Connect() override;

protected:

    // Socket configuration applied right after socket creation.
    void ConfigurePre()
    {
#if defined(__APPLE__)
        int optval = 1;
        setsockopt(m_socket, SOL_SOCKET, SO_NOSIGPIPE, &optval, sizeof(optval));
#endif
    }

    // Nothing to configure after connecting/accepting for TCP.
    void ConfigurePost(int)
    {
    }

    using Medium::Error;

    // Throws TransmissionError describing the system error `verrno`.
    static void Error(int verrno, const string& text)
    {
        char rbuf[1024];
        throw TransmissionError("ERROR: " + text + ": " + SysStrError(verrno, rbuf, 1024));
    }

    virtual ~TcpMedium()
    {
        Close();
    }
};
// Rejects options this application controls itself and forces the SRT
// transmission type to "file".
void SrtMedium::Init()
{
    const bool has_mode = m_options.count("mode") != 0;
    if (has_mode)
        Error("No option 'mode' is required, it defaults to position of the argument");

    const bool has_blocking = m_options.count("blocking") != 0;
    if (has_blocking)
        Error("Blocking is not configurable here.");

    // NOTE(review): presumably "file" because the tunnel relays TCP-like byte streams.
    m_options["transtype"] = "file";
}
// Forces the "mode" option to "caller" and applies the options that must be
// set before connecting/binding; options that failed are reported on stderr.
void SrtMedium::ConfigurePre()
{
    m_options["mode"] = "caller";

    vector<string> rejected;
    SrtConfigurePre(m_socket, "", m_options, &rejected);
    if (rejected.empty())
        return;

    cerr << "Failed options: " << Printable(rejected) << endl;
}
// Applies the post-connection options to `so`; failed options go to stderr.
void SrtMedium::ConfigurePost(SRTSOCKET so)
{
    vector<string> rejected;
    SrtConfigurePost(so, m_options, &rejected);
    if (rejected.empty())
        return;

    cerr << "Failed options: " << Printable(rejected) << endl;
}
void SrtMedium::CreateListener()
{
int backlog = 5;
m_socket = srt_create_socket();
ConfigurePre();
sockaddr_any sa = CreateAddr(m_uri.host(), m_uri.portno());
int stat = srt_bind(m_socket, sa.get(), sizeof sa);
if ( stat == SRT_ERROR )
{
srt_close(m_socket);
Error(UDT::getlasterror(), "srt_bind");
}
stat = srt_listen(m_socket, backlog);
if ( stat == SRT_ERROR )
{
srt_close(m_socket);
Error(UDT::getlasterror(), "srt_listen");
}
m_listener = true;
};
void TcpMedium::CreateListener()
{
int backlog = 5;
sockaddr_any sa = CreateAddr(m_uri.host(), m_uri.portno());
m_socket = socket(sa.get()->sa_family, SOCK_STREAM, IPPROTO_TCP);
ConfigurePre();
int stat = ::bind(m_socket, sa.get(), sa.size());
if (stat == -1)
{
tcp_close(m_socket);
Error(errno, "bind");
}
stat = listen(m_socket, backlog);
if ( stat == -1 )
{
tcp_close(m_socket);
Error(errno, "listen");
}
m_listener = true;
}
// Waits for an incoming SRT connection and returns it as a new medium.
// Throws TransmissionError if srt_accept() fails.
unique_ptr<Medium> SrtMedium::Accept()
{
    sockaddr_any sa;
    SRTSOCKET s = srt_accept(m_socket, (sa.get()), (&sa.len));
    if (s == SRT_ERROR)
    {
        Error(UDT::getlasterror(), "srt_accept");
    }

    ConfigurePost(s);

    // Apply the 1s receive timeout to the accepted socket, as
    // TcpMedium::Accept does. This is the socket the new medium reads from.
    // The previous code set it on the listening socket (m_socket) instead.
    int timeout_1s = 1000;
    srt_setsockflag(s, SRTO_RCVTIMEO, &timeout_1s, sizeof timeout_1s);

    unique_ptr<Medium> med(CreateAcceptor(this, sa, s, m_chunk));
    Verb() << "accepted a connection from " << med->uri();

    return med;
}
// Waits for an incoming TCP connection, gives it a 1s receive timeout and
// returns it as a new medium. Throws TransmissionError if accept() fails.
unique_ptr<Medium> TcpMedium::Accept()
{
    sockaddr_any sa;
    int s = ::accept(m_socket, (sa.get()), (&sa.syslen()));
    if (s == -1)
    {
        Error(errno, "accept");
    }

    timeval timeout_1s { 1, 0 };
    int st = setsockopt(s, SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast<char*>(&timeout_1s), sizeof timeout_1s);

    // Read the value back only for the debug log.
    timeval re;
    socklen_t size = sizeof re;
    int st2 = getsockopt(s, SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast<char*>(&re), &size);

    // Report the accepted socket, which is the one the option was set on,
    // not the listener m_socket.
    LOGP(applog.Debug, "Setting SO_RCVTIMEO to @", s, ": ", st == -1 ? "FAILED" : "SUCCEEDED",
            ", read-back value: ", st2 == -1 ? int64_t(-1) : (int64_t(re.tv_sec)*1000000 + re.tv_usec)/1000, "ms");

    unique_ptr<Medium> med(CreateAcceptor(this, sa, s, m_chunk));
    Verb() << "accepted a connection from " << med->uri();

    return med;
}
void SrtMedium::CreateCaller()
{
m_socket = srt_create_socket();
ConfigurePre();
}
void TcpMedium::CreateCaller()
{
m_socket = ::socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
ConfigurePre();
}
// Connects to the URI's host/port, applies post-connect options and sets a
// 1s receive timeout. Throws TransmissionError if srt_connect() fails.
void SrtMedium::Connect()
{
    sockaddr_any sa = CreateAddr(m_uri.host(), m_uri.portno());
    // Pass the real address length (as the TCP medium does) rather than the
    // size of the whole sockaddr_any container.
    int st = srt_connect(m_socket, sa.get(), sa.size());
    if (st == SRT_ERROR)
        Error(UDT::getlasterror(), "srt_connect");

    ConfigurePost(m_socket);

    // Receive timeout so ReadInternal() wakes up periodically and can notice
    // a closed (m_broken) medium.
    int timeout_1s = 1000;
    srt_setsockflag(m_socket, SRTO_RCVTIMEO, &timeout_1s, sizeof timeout_1s);
}
void TcpMedium::Connect()
{
sockaddr_any sa = CreateAddr(m_uri.host(), m_uri.portno());
int st = ::connect(m_socket, sa.get(), sa.size());
if (st == -1)
Error(errno, "connect");
ConfigurePost(m_socket);
timeval timeout_1s { 1, 0 };
setsockopt(m_socket, SOL_SOCKET, SO_RCVTIMEO, (char*)&timeout_1s, sizeof timeout_1s);
}
// Receives up to `size` bytes into w_buffer. SRT_EASYNCRCV (reported e.g.
// when the receive timeout expires) is retried until the medium is closed.
// Returns the srt_recv() result: byte count, or SRT_ERROR.
int SrtMedium::ReadInternal(char* w_buffer, int size)
{
    for (;;)
    {
        const int received = srt_recv(m_socket, (w_buffer), size);
        if (received != SRT_ERROR)
            return received;

        int syserr;
        const bool would_block = srt_getlasterror(&syserr) == SRT_EASYNCRCV;
        if (!would_block || m_broken)
            return received;
    }
}
// Receives up to `size` bytes into w_buffer. EAGAIN/EWOULDBLOCK (receive
// timeout) is retried until the medium is closed. Returns the recv() result;
// on -1, errno holds recv()'s error for IsErrorAgain().
int TcpMedium::ReadInternal(char* w_buffer, int size)
{
    int st = -1;
    int saved_errno = 0;
    LOGP(applog.Debug, "TcpMedium:recv @", m_socket, " - begin");
    for (;;)
    {
        st = ::recv(m_socket, (w_buffer), size, 0);
        if (st != -1)
            break;

        saved_errno = errno;
        if (saved_errno == EAGAIN || saved_errno == EWOULDBLOCK)
        {
            if (!m_broken)
            {
                LOGP(applog.Debug, "TcpMedium: read:AGAIN, repeating");
                continue;
            }
            LOGP(applog.Debug, "TcpMedium: read:AGAIN, not repeating - already broken");
        }
        else
        {
            LOGP(applog.Debug, "TcpMedium: read:ERROR: ", saved_errno);
        }
        break;
    }
    LOGP(applog.Debug, "TcpMedium:recv @", m_socket, " - result: ", st);

    // The logging above may change errno; restore recv()'s error so that the
    // caller (Medium::Read -> IsErrorAgain) inspects the right value.
    if (st == -1)
        errno = saved_errno;
    return st;
}
bool SrtMedium::IsErrorAgain()
{
return srt_getlasterror(NULL) == SRT_EASYNCRCV;
}
// True when the last socket error means "try again". EAGAIN and EWOULDBLOCK
// may be distinct values; ReadInternal() treats both as retryable, so this
// does too.
bool TcpMedium::IsErrorAgain()
{
    return errno == EAGAIN || errno == EWOULDBLOCK;
}
// Reads up to m_chunk more bytes and appends them to w_output, which may still
// hold bytes left over from a previous partial Write().
// Returns:
//  RD_DATA  - w_output holds data to send. This may be only the leftover data,
//             when the buffer already exceeds m_chunk or EOF was seen earlier.
//  RD_EOF   - end of stream with nothing pending (w_output is cleared).
//  RD_AGAIN - the read failed with a "try again" error (IsErrorAgain()).
//  RD_ERROR - any other read error.
// On RD_AGAIN/RD_ERROR, w_output keeps exactly its previous contents.
Medium::ReadStatus Medium::Read(bytevector& w_output)
{
    if (w_output.size() > m_chunk)
    {
        Verb() << "BUFFER EXCEEDED";
        return RD_DATA;
    }

    size_t shift = w_output.size();
    if (shift && m_eof)
    {
        // Nothing more will come; just flush what is pending.
        return RD_DATA;
    }

    size_t pred_size = shift + m_chunk;
    w_output.resize(pred_size);
    int st = ReadInternal((w_output.data() + shift), int(m_chunk));
    if (st < 0)
    {
        // Check the error first (IsErrorAgain may inspect errno), then drop
        // the unfilled space reserved above so no padding is sent later.
        const ReadStatus failure = IsErrorAgain() ? RD_AGAIN : RD_ERROR;
        w_output.resize(shift);
        return failure;
    }

    if (st == 0)
    {
        m_eof = true;
        if (shift)
        {
            w_output.resize(shift);
            return RD_DATA;
        }
        w_output.clear();
        return RD_EOF;
    }

    w_output.resize(shift+st);
    return RD_DATA;
}
// Sends w_buffer over SRT and removes the bytes that were sent, leaving any
// unsent remainder in the buffer. Throws TransmissionError on error or when
// nothing could be sent.
void SrtMedium::Write(bytevector& w_buffer)
{
    const int sent = srt_send(m_socket, w_buffer.data(), w_buffer.size());
    if (sent == SRT_ERROR)
        Error(UDT::getlasterror(), "srt_send");

    const bool everything_sent = sent >= int(w_buffer.size());
    if (everything_sent)
    {
        w_buffer.clear();
    }
    else if (sent == 0)
    {
        Error("Unexpected EOF on Write");
    }
    else
    {
        // Partial send: keep only what did not go out.
        w_buffer.erase(w_buffer.begin(), w_buffer.begin() + sent);
    }
}
// Sends w_buffer over TCP and removes the bytes that were sent, leaving any
// unsent remainder in the buffer. Throws TransmissionError on error or when
// nothing could be sent.
void TcpMedium::Write(bytevector& w_buffer)
{
    const int sent = ::send(m_socket, w_buffer.data(), w_buffer.size(), DEF_SEND_FLAG);
    if (sent == -1)
        Error(errno, "send");

    const bool everything_sent = sent >= int(w_buffer.size());
    if (everything_sent)
    {
        w_buffer.clear();
    }
    else if (sent == 0)
    {
        Error("Unexpected EOF on Write");
    }
    else
    {
        // Partial send: keep only what did not go out.
        w_buffer.erase(w_buffer.begin(), w_buffer.begin() + sent);
    }
}
std::unique_ptr<Medium> Medium::Create(const std::string& url, size_t chunk, Medium::Mode mode)
{
UriParser uri(url);
std::unique_ptr<Medium> out;
if (uri.scheme() == "srt")
{
out.reset(new SrtMedium(uri, chunk));
}
else if (uri.scheme() == "tcp")
{
out.reset(new TcpMedium(uri, chunk));
}
else
{
Error("Medium not supported");
}
out->InitMode(mode);
return out;
}
struct Tunnelbox
{
list<unique_ptr<Tunnel>> tunnels;
mutex access;
condition_variable decom_ready;
bool main_running = true;
thread thr;
void signal_decommission()
{
lock_guard<mutex> lk(access);
decom_ready.notify_one();
}
void install(std::unique_ptr<Medium>&& acp, std::unique_ptr<Medium>&& clr)
{
lock_guard<mutex> lk(access);
Verb() << "Tunnelbox: Starting tunnel: " << acp->uri() << " <-> " << clr->uri();
tunnels.emplace_back(new Tunnel(this, move(acp), move(clr)));
auto& it = tunnels.back();
it->Start();
}
void start_cleaner()
{
thr = thread( [this]() { CleanupWorker(); } );
}
void stop_cleaner()
{
if (thr.joinable())
thr.join();
}
private:
void CleanupWorker()
{
unique_lock<mutex> lk(access);
while (main_running)
{
decom_ready.wait(lk);
for (auto i = tunnels.begin(), i_next = i; i != tunnels.end(); i = i_next)
{
++i_next;
if ((*i)->decommission_if_dead(main_running))
{
tunnels.erase(i);
}
}
}
}
};
// Marks the tunnel as no longer running and asks the box's cleaner to
// decommission it. Called from the engine threads; only the first call has an effect.
void Tunnel::Stop()
{
    // Unlocked pre-check. The cleaner locks `access` in decommission_if_dead()
    // while joining the engine threads, so an engine calling Stop() for an
    // already stopped tunnel must return without touching the lock.
    if (!running)
        return;

    {
        lock_guard<mutex> lk(access);
        // Both engines may pass the pre-check at almost the same time; only
        // the first one to take the lock reports the tunnel.
        if (!running)
            return;
        running = false;
    }

    // Signal outside `access`. signal_decommission() may need a lock that the
    // cleaner holds while it locks this tunnel's `access`. Signalling with
    // `access` held could therefore deadlock.
    parent_box->signal_decommission();
}
// Stops and joins both engines if the tunnel is no longer running, or
// unconditionally when `forced`. Returns true when the tunnel was
// decommissioned and may be deleted by the caller.
bool Tunnel::decommission_if_dead(bool forced)
{
    lock_guard<mutex> lk(access);
    if (running && !forced)
        return false; // still working, leave it alone

    // Mark the tunnel stopped before joining. An engine that exits because
    // Engine::Stop() closed its media calls Tunnel::Stop(). That call must
    // return early instead of waiting for `access`, which is held here, or the
    // join below would never finish.
    running = false;

    acp_to_clr.Stop();
    clr_to_acp.Stop();

    return true;
}
// Next instance number handed to a new Medium (shown in Medium::id()).
int Medium::s_counter {1};

// Process-wide tunnel registry with its cleaner thread.
Tunnelbox g_tunnels;

// Listening medium created in main() and closed by OnINT_StopService().
std::unique_ptr<Medium> main_listener;

// Default number of bytes read per Read() call; -c/--chunk overrides it.
size_t default_chunk {4096};
// Stops the service: tells the tunnel box to shut down and closes the
// listener so that main()'s accept loop ends.
// NOTE(review): nothing in this file installs this function, and its
// int(int) signature does not match a signal() handler — confirm how it is
// meant to be hooked up.
int OnINT_StopService(int)
{
    g_tunnels.main_running = false;
    g_tunnels.signal_decommission();

    // main_listener is only set once main() has created the listener.
    if (main_listener)
        main_listener->Close();

    return 0;
}
// Usage: <program> [options] <listen-uri> <call-uri>
// Accepts connections on <listen-uri> (srt:// or tcp://), connects each one
// to <call-uri> and relays data both ways through a Tunnel.
int main( int argc, char** argv )
{
    if (!SysInitializeNetwork())
    {
        cerr << "Fail to initialize network module." << endl;
        return 1;
    }

    size_t chunk = default_chunk;

    OptionName
        o_loglevel = { "ll", "loglevel" },
        o_logfa = { "lf", "logfa" },
        o_chunk = {"c", "chunk" },
        o_verbose = {"v", "verbose" },
        o_noflush = {"s", "skipflush" };

    vector<OptionScheme> optargs = {
        { o_loglevel, OptionScheme::ARG_ONE },
        { o_logfa, OptionScheme::ARG_ONE },
        { o_chunk, OptionScheme::ARG_ONE }
    };

    options_t params = ProcessOptions(argv, argc, optargs);

    vector<string> args = params[""];
    if ( args.size() < 2 )
    {
        cerr << "Usage: " << argv[0] << " <listen-uri> <call-uri>\n";
        return 1;
    }

    string loglevel = Option<OutString>(params, "error", o_loglevel);
    string logfa = Option<OutString>(params, "", o_logfa);
    srt_logging::LogLevel::type lev = SrtParseLogLevel(loglevel);
    UDT::setloglevel(lev);
    if (logfa == "")
    {
        UDT::addlogfa(SRT_LOGFA_APP);
    }
    else
    {
        set<string> unknown_fas;
        set<srt_logging::LogFA> fas = SrtParseLogFA(logfa, &unknown_fas);
        UDT::resetlogfa(fas);

        // "app" is this program's own FA, which SrtParseLogFA does not know.
        if (unknown_fas.count("app"))
            UDT::addlogfa(SRT_LOGFA_APP);
    }

    string verbo = Option<OutString>(params, "no", o_verbose);
    if ( verbo == "" || !false_names.count(verbo) )
    {
        Verbose::on = true;
        Verbose::cverb = &std::cout;
    }

    string chunks = Option<OutString>(params, "", o_chunk);
    if ( chunks != "" )
    {
        // stoi() throws on non-numeric input, and a negative value would wrap
        // around to a huge size_t; reject both with a message.
        int requested = 0;
        try
        {
            requested = stoi(chunks);
        }
        catch (const std::exception&)
        {
            requested = 0;
        }

        if (requested <= 0)
        {
            cerr << "ERROR: invalid chunk size: " << chunks << endl;
            return 1;
        }
        chunk = size_t(requested);
    }

    string listen_node = args[0];
    string call_node = args[1];

    UriParser ul(listen_node), uc(call_node);

    set<string> allowed = {"srt", "tcp"};
    if (!allowed.count(ul.scheme())|| !allowed.count(uc.scheme()))
    {
        cerr << "ERROR: only tcp and srt schemes supported" << endl;
        return -1;
    }

    Verb() << "LISTEN type=" << ul.scheme() << ", CALL type=" << uc.scheme();

    // Create the listener before starting the cleaner thread. A setup failure
    // can then return cleanly instead of escaping as an uncaught exception,
    // and no joinable std::thread is left behind; destroying a joinable
    // std::thread at exit calls std::terminate.
    try
    {
        main_listener = Medium::Create(listen_node, chunk, Medium::LISTENER);
    }
    catch (const std::exception& x)
    {
        cerr << "ERROR: cannot create listener: " << x.what() << endl;
        return 1;
    }

    g_tunnels.start_cleaner();

    for (;;)
    {
        try
        {
            Verb() << "Waiting for connection...";
            std::unique_ptr<Medium> accepted = main_listener->Accept();

            if (!g_tunnels.main_running)
            {
                Verb() << "Service stopped. Exiting.";
                break;
            }

            Verb() << "Connection accepted. Connecting to the relay...";
            std::unique_ptr<Medium> caller = Medium::Create(call_node, chunk, Medium::CALLER);
            caller->Connect();

            Verb() << "Connected. Establishing pipe.";
            g_tunnels.install(move(accepted), move(caller));
        }
        catch (const std::exception& x)
        {
            Verb() << "Connection reported, but failed: " << x.what();
        }
        catch (...)
        {
            Verb() << "Connection reported, but failed";
        }

        // Stopping the service closes the listener, so Accept() would fail
        // immediately on every iteration; leave instead of spinning.
        if (!g_tunnels.main_running)
        {
            Verb() << "Service stopped. Exiting.";
            break;
        }
    }

    g_tunnels.stop_cleaner();

    return 0;
}