#include <iostream>
#include <boost/asio/io_context.hpp>
#include <boost/asio/ip/host_name.hpp>
#include <boost/asio/placeholders.hpp>
#include <boost/asio/ip/tcp.hpp>
#include <boost/asio/read_until.hpp>
#include <boost/asio/write.hpp>
#include <boost/algorithm/string/trim.hpp>
#include <boost/algorithm/string/predicate.hpp>
#include <boost/algorithm/string/split.hpp>
#include <boost/algorithm/string/case_conv.hpp>
#include <boost/uuid/random_generator.hpp>
#include <boost/uuid/uuid_io.hpp>
#include <boost/thread/thread_only.hpp>
#include "cast.h"
#include "consumer_queue.h"
#include "sample.h"
#include "send_buffer.h"
#include "socket_utils.h"
#include "stream_info_impl.h"
#include "tcp_server.h"
#define NO_EXPLICIT_TEMPLATE_INSTANTIATION
#include "portable_archive/portable_oarchive.hpp"
using namespace lsl;
using namespace lslboost::asio;
/**
 * Construct the data server for one outlet.
 *
 * Opens an acceptor for the given TCP protocol family and binds it via
 * bind_and_listen_to_port_in_range (called with 10 as its last argument;
 * NOTE(review): meaning of that constant is defined elsewhere — confirm).
 * Then stamps the stream info with the session id, a freshly generated
 * random UID, the creation time, the host name and the bound data port.
 *
 * @param info       stream description that is updated in place
 * @param io         io_context used for the acceptor and client sockets
 * @param sendbuf    buffer that client transfer threads consume samples from
 * @param factory    sample factory (used e.g. by end_serving)
 * @param protocol   tcp::v4() or tcp::v6(); selects which data port is set
 * @param chunk_size server-side default chunk size (0 = none)
 */
tcp_server::tcp_server(const stream_info_impl_p &info, const io_context_p &io,
	const send_buffer_p &sendbuf, const factory_p &factory, tcp protocol, int chunk_size)
	: chunk_size_(chunk_size), shutdown_(false), info_(info), io_(io), factory_(factory),
	  send_buffer_(sendbuf), acceptor_(new tcp::acceptor(*io)) {
	acceptor_->open(protocol);
	const uint16_t bound_port = bind_and_listen_to_port_in_range(*acceptor_, protocol, 10);

	// identify this particular instance of the stream
	info_->session_id(api_config::get_instance()->session_id());
	lslboost::uuids::random_generator make_uuid;
	info_->uid(lslboost::uuids::to_string(make_uuid()));
	info_->created_at(lsl_clock());
	info_->hostname(ip::host_name());

	// publish the port under the protocol family it was bound for
	if (protocol == tcp::v4())
		info_->v4data_port(bound_port);
	else
		info_->v6data_port(bound_port);
}
void tcp_server::begin_serving() {
shortinfo_msg_ = info_->to_shortinfo_message();
fullinfo_msg_ = info_->to_fullinfo_message();
accept_next_connection();
}
/**
 * Stop serving clients.
 *
 * Steps:
 *  1. Set shutdown_.
 *  2. Post the closing of the acceptor to the io_context.
 *  3. Post the closing of every registered in-flight client socket to the io_context.
 *  4. Push one extra sample (timestamp lsl_clock(), second argument true) into
 *     the send buffer.
 *
 * NOTE(review): the final sample presumably exists only to wake transfer
 * threads blocked on their consumer queue so they observe shutdown_.
 */
void tcp_server::end_serving() {
	shutdown_ = true;

	// both close operations run on an io_context thread, not on the caller's
	post(*io_, lslboost::bind(&tcp::acceptor::close, acceptor_));
	close_inflight_sockets();

	const double now = lsl_clock();
	send_buffer_->push_sample(factory_->new_sample(now, true));
}
void tcp_server::accept_next_connection() {
try {
client_session_p newsession(new client_session(shared_from_this()));
acceptor_->async_accept(*newsession->socket(),
lslboost::bind(&tcp_server::handle_accept_outcome,shared_from_this(),newsession,placeholders::error));
} catch(std::exception &e) {
std::cerr << "Error during tcp_server::accept_next_connection (id: " << lslboost::this_thread::get_id() << "): " << e.what() << std::endl;
}
}
/**
 * Completion handler of async_accept.
 *
 * The accept loop ends when:
 *  - the operation was aborted,
 *  - the socket was shut down, or
 *  - the server is shutting down.
 * Otherwise the new session starts processing (only if the accept itself
 * succeeded), and the next accept is issued in either case.
 */
void tcp_server::handle_accept_outcome(client_session_p newsession, error_code err) {
	if (err == error::operation_aborted || err == error::shut_down || shutdown_)
		return;
	if (!err)
		newsession->begin_processing();
	accept_next_connection();
}
void tcp_server::register_inflight_socket(const tcp_socket_p &sock) {
lslboost::lock_guard<lslboost::recursive_mutex> lock(inflight_mut_);
inflight_.insert(sock);
}
void tcp_server::unregister_inflight_socket(const tcp_socket_p &sock) {
lslboost::lock_guard<lslboost::recursive_mutex> lock(inflight_mut_);
inflight_.erase(sock);
}
/**
 * Best-effort shutdown and close of an open socket.
 *
 * Errors from shutdown() are ignored on purpose, so that close() is still
 * attempted. Any std::exception from is_open() or close() is logged to
 * std::cerr instead of propagating.
 *
 * @tparam SocketPtr pointer-like handle to a socket
 * @tparam Protocol  protocol type providing Protocol::socket::shutdown_both
 */
template<class SocketPtr, class Protocol> void shutdown_and_close(SocketPtr sock) {
	try {
		if (!sock->is_open())
			return;
		try {
			sock->shutdown(Protocol::socket::shutdown_both);
		} catch (...) {
			// ignore: the socket is closed below regardless
		}
		sock->close();
	} catch (std::exception &e) {
		std::cerr << "Error during shutdown_and_close (thread id: " << lslboost::this_thread::get_id() << "): " << e.what() << std::endl;
	}
}
void tcp_server::close_inflight_sockets() {
lslboost::lock_guard<lslboost::recursive_mutex> lock(inflight_mut_);
for (std::set<tcp_socket_p>::iterator i=inflight_.begin(); i!=inflight_.end(); i++)
post(*io_, lslboost::bind(&shutdown_and_close<tcp_socket_p, tcp>, *i));
}
/**
 * Create a session bound to the given server.
 *
 * The socket is created on the server's io_context. The acceptor later
 * connects it (see tcp_server::accept_next_connection).
 *
 * @param serv owning server; the session keeps a reference to it
 */
tcp_server::client_session::client_session(const tcp_server_p &serv):
	registered_(false), io_(serv->io_), serv_(serv), sock_(tcp_socket_p(new tcp::socket(*serv->io_))),
	requeststream_(&requestbuf_), use_byte_order_(0), data_protocol_version_(100) {
	// handle_read_feedparams only assigns these when the client sends the
	// matching request fields (Max-Buffer-Length / Max-Chunk-Length for
	// protocol >= 1.10). Give them defined defaults so transfer_samples_thread
	// never reads indeterminate values.
	//  - max_buffered_ == 0: the transfer thread returns without streaming.
	//  - chunk_granularity_ == 0: the server's chunk_size_ is used instead.
	// Assigned in the body (not the init list) to stay independent of the
	// member declaration order in the header.
	max_buffered_ = 0;
	chunk_granularity_ = 0;
}
/**
 * Destroy the session.
 *
 * If begin_processing() registered the socket with the server, it is
 * unregistered here. A destructor must not throw, so every failure is
 * reported on std::cerr and swallowed.
 */
tcp_server::client_session::~client_session() {
	try {
		if (!registered_)
			return;
		serv_->unregister_inflight_socket(sock_);
	} catch (std::exception &e) {
		std::cerr << "Unexpected error in client_session destructor (id: " << lslboost::this_thread::get_id() << "): " << e.what() << std::endl;
	} catch (...) {
		std::cerr << "Severe error during client session shutdown." << std::endl;
	}
}
/// Shared handle to this session's socket (the acceptor accepts into it).
tcp_socket_p tcp_server::client_session::socket() {
	return this->sock_;
}
void tcp_server::client_session::begin_processing() {
try {
sock_->set_option(lslboost::asio::ip::tcp::no_delay(true));
serv_->register_inflight_socket(sock_);
registered_ = true;
async_read_until(*sock_, requestbuf_, "\r\n",
lslboost::bind(&client_session::handle_read_command_outcome,shared_from_this(),placeholders::error));
} catch(std::exception &e) {
std::cerr << "Error during client_session::begin_processing (id: " << lslboost::this_thread::get_id() << "): " << e.what() << std::endl;
}
}
/**
 * Completion handler for the first (command) line of a client request.
 *
 * Reads one line from requeststream_, trims it and dispatches on it:
 *  - "LSL:shortinfo": read one more CRLF-terminated line (a query) and
 *    continue in handle_read_query_outcome.
 *  - "LSL:fullinfo": send the server's cached full info message.
 *  - "LSL:streamfeed": protocol 1.00 feed request. Read one more
 *    CRLF-terminated line of parameters and continue in
 *    handle_read_feedparams with version 100 and no uid.
 *  - "LSL:streamfeed/<version> [<uid>]": versioned feed request. Read the
 *    header block up to the empty line ("\r\n\r\n") and continue in
 *    handle_read_feedparams with the parsed version and the uid (empty if
 *    none was given).
 * Any other command starts no further I/O.
 *
 * @param err outcome of the preceding async_read_until; nothing happens on error
 *
 * Exceptions (e.g. from from_string on a malformed version) are logged to std::cerr.
 */
void tcp_server::client_session::handle_read_command_outcome(error_code err) {
	try {
		if (err)
			return;
		std::string method;
		getline(requeststream_, method);
		lslboost::trim(method);
		if (method == "LSL:shortinfo") {
			async_read_until(*sock_, requestbuf_, "\r\n",
				lslboost::bind(&client_session::handle_read_query_outcome, shared_from_this(), placeholders::error));
		} else if (method == "LSL:fullinfo") {
			async_write(*sock_, lslboost::asio::buffer(serv_->fullinfo_msg_),
				lslboost::bind(&client_session::handle_send_outcome, shared_from_this(), placeholders::error));
		} else if (method == "LSL:streamfeed") {
			async_read_until(*sock_, requestbuf_, "\r\n",
				lslboost::bind(&client_session::handle_read_feedparams, shared_from_this(), 100, "", placeholders::error));
		} else if (lslboost::algorithm::starts_with(method, "LSL:streamfeed/")) {
			// Collapse runs of blanks. Otherwise a doubled space yields an empty
			// second token, which would silently mean "no uid requested" and
			// make this stream answer a request meant for a specific uid.
			std::vector<std::string> parts;
			lslboost::algorithm::split(parts, method, lslboost::algorithm::is_any_of(" \t"),
				lslboost::algorithm::token_compress_on);
			// method is trimmed and starts with "LSL:streamfeed/", so parts[0]
			// exists and contains the '/'
			const std::string &command = parts[0];
			const int request_protocol_version = from_string<int>(command.substr(command.find('/') + 1));
			const std::string request_uid = (parts.size() > 1) ? parts[1] : std::string();
			async_read_until(*sock_, requestbuf_, "\r\n\r\n",
				lslboost::bind(&client_session::handle_read_feedparams, shared_from_this(),
					request_protocol_version, request_uid, placeholders::error));
		}
	} catch (std::exception &e) {
		std::cerr << "Unexpected error while parsing a client command (id: " << lslboost::this_thread::get_id() << "): " << e.what() << std::endl;
	}
}
/**
 * Completion handler for the query line of an "LSL:shortinfo" request.
 *
 * If the stream info matches the (trimmed) query, the cached short info
 * message is sent. Otherwise nothing is sent. Errors are logged to std::cerr.
 */
void tcp_server::client_session::handle_read_query_outcome(error_code err) {
	try {
		if (err)
			return;
		std::string query;
		getline(requeststream_, query);
		lslboost::trim(query);
		if (!serv_->info_->matches_query(query))
			return;
		async_write(*sock_, lslboost::asio::buffer(serv_->shortinfo_msg_),
			lslboost::bind(&client_session::handle_send_outcome, shared_from_this(), placeholders::error));
	} catch (std::exception &e) {
		std::cerr << "Unexpected error while parsing a client request (id: " << lslboost::this_thread::get_id() << "): " << e.what() << std::endl;
	}
}
/**
 * Completion handler for info replies. It does nothing, and the write
 * outcome is ignored. The shared_ptr bound into the handler kept the session
 * alive until this point.
 */
void tcp_server::client_session::handle_send_outcome(error_code /*err*/) {
	// intentionally empty
}
void tcp_server::client_session::send_status_message(const std::string &str) {
string_p msg(new std::string(str));
async_write(*sock_, lslboost::asio::buffer(*msg),
lslboost::bind(&client_session::handle_status_outcome,shared_from_this(),msg,placeholders::error));
}
/**
 * Completion handler for send_status_message. It does nothing. Its only job
 * is to receive (and thereby hold) the message string until the write
 * completes. The write outcome is ignored.
 */
void tcp_server::client_session::handle_status_outcome(string_p /*msg*/, error_code /*err*/) {
	// intentionally empty
}
/**
 * Completion handler for reading the parameters of a stream-feed request.
 *
 * Steps:
 *  1. Reject a request whose major protocol version is newer than ours
 *     ("505 Version not supported").
 *  2. Reject a request for a different stream UID ("404 Not found").
 *  3. Negotiate the data transmission parameters.
 *  4. Write the feed header (response header for >= 1.10, then two
 *     test-pattern samples) into feedbuf_.
 *  5. Start an asynchronous write of feedbuf_; handle_send_feedheader_outcome
 *     runs when it completes.
 *
 * @param request_protocol_version protocol version requested by the client
 *        (100 for the plain "LSL:streamfeed" command)
 * @param request_uid requested stream UID, or empty to accept any stream
 * @param err outcome of the preceding async_read_until; nothing is done on error
 *
 * Exceptions (presumably e.g. from from_string on malformed header values —
 * confirm its failure behavior) are caught and logged to std::cerr. If that
 * happens before the final async_write, no feed header is sent.
 */
void tcp_server::client_session::handle_read_feedparams(int request_protocol_version, std::string request_uid, error_code err) {
try {
if (!err) {
using namespace lslboost::algorithm;
// Reject clients whose major protocol version (version/100) is newer than ours.
if (request_protocol_version/100 > api_config::get_instance()->use_protocol_version()/100) {
send_status_message("LSL/" +
to_string(api_config::get_instance()->use_protocol_version()) +
" 505 Version not supported");
return;
}
// A non-empty requested UID must match this stream's UID exactly.
if (!request_uid.empty() && request_uid != serv_->info_->uid()) {
send_status_message("LSL/" +
to_string(api_config::get_instance()->use_protocol_version()) +
" 404 Not found");
return;
}
if (request_protocol_version >= 110) {
// Defaults assumed for any header the client omits:
//  - byte order 1234 (Boost's code for little-endian)
//  - endian performance 0
//  - IEEE 754 floats and subnormals supported
//  - the requested protocol version
//  - our own channel value size and format
int client_byte_order = 1234; double client_endian_performance = 0; bool client_has_ieee754_floats = true; bool client_supports_subnormals = true; int client_protocol_version = request_protocol_version; int client_value_size = serv_->info_->channel_bytes(); lsl_channel_format_t format = serv_->info_->channel_format();
// Parse "Key: value" header lines until the empty line (a line starting with '\r').
// Lines longer than the buffer make getline fail, which also ends the loop.
char buf[16384] = {0};
while (requeststream_.getline(buf,sizeof(buf)) && (buf[0] != '\r')) {
std::string hdrline(buf);
std::size_t colon = hdrline.find_first_of(':');
if (colon != std::string::npos) {
// Key and value are trimmed and lower-cased; anything after ';' in the value is dropped.
// NOTE(review): whitespace between the value and ';' is not trimmed after the cut.
std::string type = to_lower_copy(trim_copy(hdrline.substr(0,colon))), rest = to_lower_copy(trim_copy(hdrline.substr(colon+1)));
std::size_t semicolon = rest.find_first_of(';');
if (semicolon != std::string::npos)
rest = rest.substr(0,semicolon);
if (type == "native-byte-order")
client_byte_order = from_string<int>(rest);
if (type == "endian-performance")
client_endian_performance = from_string<double>(rest);
if (type == "has-ieee754-floats")
client_has_ieee754_floats = from_string<bool>(rest);
if (type == "supports-subnormals")
client_supports_subnormals = from_string<bool>(rest);
if (type == "value-size")
client_value_size = from_string<int>(rest);
// NOTE(review): if these two headers are absent, max_buffered_ / chunk_granularity_
// keep whatever value they had before this call.
if (type == "max-buffer-length")
max_buffered_ = from_string<int>(rest);
if (type == "max-chunk-length")
chunk_granularity_ = from_string<int>(rest);
if (type == "protocol-version")
client_protocol_version = from_string<int>(rest);
}
}
bool client_suppress_subnormals = false;
// Use the lower of both protocol versions, then fall back to 1.00 (portable archive)
// when value sizes differ or IEEE 754 floats are unavailable on either side.
// NOTE(review): format_ieee754 is assumed to describe this machine's float formats — confirm.
data_protocol_version_ = std::min(api_config::get_instance()->use_protocol_version(),client_protocol_version);
if (serv_->info_->channel_bytes() != client_value_size)
data_protocol_version_ = 100;
if (!format_ieee754[cft_double64] || (format==cft_float32 && !format_ieee754[cft_float32]) || !client_has_ieee754_floats)
data_protocol_version_ = 100;
if (data_protocol_version_ >= 110) {
// Choose the wire byte order when ours differs from the client's.
if (BOOST_BYTE_ORDER != client_byte_order) {
// 2134 is presumably Boost's PDP/mixed-endian code. For values of >= 8 bytes
// we keep our native order, leaving any conversion to the client.
if (client_byte_order == 2134 && client_value_size>=8) {
use_byte_order_ = BOOST_BYTE_ORDER;
} else {
// Use the client's order when single-byte values need no conversion, or when
// our measured endian-conversion score beats the client's reported one.
// NOTE(review): assumes a larger score means faster conversion — confirm.
use_byte_order_ = (client_value_size<=1 || (measure_endian_performance()>client_endian_performance)) ? client_byte_order : BOOST_BYTE_ORDER;
}
} else
use_byte_order_ = BOOST_BYTE_ORDER;
// Ask for subnormal suppression if this format can have subnormals and the client cannot handle them.
client_suppress_subnormals = (format_subnormal[format] && !client_supports_subnormals);
}
// Response header: status line, then the negotiated parameters, terminated by an empty line.
std::ostream response_stream(&feedbuf_);
response_stream << "LSL/" << api_config::get_instance()->use_protocol_version() << " 200 OK\r\n";
response_stream << "UID: " << serv_->info_->uid() << "\r\n";
response_stream << "Byte-Order: " << use_byte_order_ << "\r\n";
response_stream << "Suppress-Subnormals: " << client_suppress_subnormals << "\r\n";
response_stream << "Data-Protocol-Version: " << data_protocol_version_ << "\r\n";
response_stream << "\r\n" << std::flush;
} else {
// Protocol 1.00: the parameter line holds max_buffered_ and chunk_granularity_, whitespace-separated.
requeststream_ >> max_buffered_ >> chunk_granularity_;
}
// 1.00 serializes through a portable archive on feedbuf_ and starts with the short info.
// >= 1.10 writes samples directly and gets a scratch buffer of one sample's size
// (format_sizes[format] * channel_count bytes), presumably for byte-order conversion.
if (data_protocol_version_ == 100) {
outarch_.reset(new eos::portable_oarchive(feedbuf_));
*outarch_ << serv_->shortinfo_msg_;
} else {
scratch_.reset(new char[format_sizes[serv_->info_->channel_format()]*serv_->info_->channel_count()]);
}
// Append two test-pattern samples (patterns 4 and 2) in the negotiated encoding.
// NOTE(review): presumably lets the client verify it decodes the format correctly.
lslboost::scoped_ptr<sample> temp(factory::new_sample_unmanaged(
serv_->info_->channel_format(), serv_->info_->channel_count(), 0.0, false));
temp->assign_test_pattern(4);
if (data_protocol_version_ >= 110)
temp->save_streambuf(feedbuf_,data_protocol_version_,use_byte_order_,scratch_.get());
else
*outarch_ << *temp;
temp->assign_test_pattern(2);
if (data_protocol_version_ >= 110)
temp->save_streambuf(feedbuf_,data_protocol_version_,use_byte_order_,scratch_.get());
else
*outarch_ << *temp;
// Send everything buffered so far (header + test patterns).
async_write(*sock_,feedbuf_.data(),
lslboost::bind(&client_session::handle_send_feedheader_outcome,shared_from_this(),placeholders::error,placeholders::bytes_transferred));
}
} catch(std::exception &e) {
std::cerr << "Unexpected error while serializing the feed header (id: " << lslboost::this_thread::get_id() << "): " << e.what() << std::endl;
}
}
/**
 * Completion handler for the feed header write.
 *
 * On success:
 *  - discard the n written bytes from feedbuf_;
 *  - create work_ from the server's io_context executor (presumably an
 *    executor work guard keeping the io_context running while samples are
 *    streamed — confirm against the work_p typedef);
 *  - launch transfer_samples_thread on its own thread. That thread holds a
 *    shared_ptr to this session.
 *
 * The thread is detached explicitly rather than by letting a temporary
 * lslboost::thread go out of scope. A joinable thread's destructor calls
 * std::terminate when
 * BOOST_THREAD_PROVIDES_THREAD_DESTRUCTOR_CALLS_TERMINATE_IF_JOINABLE is in
 * effect (the default for BOOST_THREAD_VERSION >= 3).
 *
 * Errors are logged to std::cerr.
 */
void tcp_server::client_session::handle_send_feedheader_outcome(error_code err, std::size_t n) {
	try {
		if (!err) {
			feedbuf_.consume(n);
			work_.reset(new work_p::element_type(serv_->io_->get_executor()));
			lslboost::thread transfer(&client_session::transfer_samples_thread, this, shared_from_this());
			transfer.detach();
		}
	} catch (std::exception &e) {
		std::cerr << "Unexpected error while handling the feedheader send outcome (id: " << lslboost::this_thread::get_id() << "): " << e.what() << std::endl;
	}
}
/**
 * Body of the per-client streaming thread (started by handle_send_feedheader_outcome).
 *
 * Registers a consumer on the server's send buffer. Until the server's
 * shutdown_ flag is set, it serializes each sample popped from that consumer
 * into feedbuf_. When a sample's pushthrough flag is set:
 *  - the accumulated bytes are written to the socket with async_write;
 *  - this thread blocks on completion_cond_ until
 *    handle_chunk_transfer_outcome records the result;
 *  - a write error ends the loop.
 *
 * The unnamed client_session_p parameter is never read. Holding it keeps this
 * session object alive while the thread runs.
 *
 * Exceptions while handling one sample are logged and the loop continues.
 * Exceptions outside the loop (e.g. from new_consumer) end the thread.
 */
void tcp_server::client_session::transfer_samples_thread(client_session_p) {
// A non-positive buffer length means no samples are streamed at all.
if (max_buffered_ <= 0)
return;
try {
// NOTE(review): presumably max_buffered_ bounds the consumer queue size — confirm in send_buffer.
consumer_queue_p queue = serv_->send_buffer_->new_consumer(max_buffered_);
// Running sample count; only used to decide which samples end a chunk.
uint32_t seqn = 0;
while (!serv_->shutdown_) {
try {
// NOTE(review): presumably blocks until a sample is available — confirm in consumer_queue.
sample_p samp(queue->pop_sample());
// Re-check after waking. end_serving() sets shutdown_ before pushing a final sample,
// which presumably is what wakes this thread.
if (serv_->shutdown_)
break;
if (!samp)
continue;
// Override pushthrough so that every N-th sample ends a chunk. N is the client's
// chunk granularity if non-zero, else the server's chunk_size_ if non-zero. When both
// are zero, the sample's own pushthrough flag is kept.
if (chunk_granularity_)
samp->pushthrough = (((++seqn) % (uint32_t)chunk_granularity_) == 0);
else
if (serv_->chunk_size_)
samp->pushthrough = (((++seqn) % (uint32_t)serv_->chunk_size_) == 0);
// Serialize with save_streambuf for protocol >= 1.10, or via the portable archive for 1.00.
if (data_protocol_version_ >= 110)
samp->save_streambuf(feedbuf_,data_protocol_version_,use_byte_order_,scratch_.get());
else
*outarch_ << *samp;
// At a chunk boundary, flush everything buffered so far and wait for the write to finish.
if (samp->pushthrough) {
lslboost::unique_lock<lslboost::mutex> lock(completion_mut_);
transfer_completed_ = false;
// handle_chunk_transfer_outcome stores the result under completion_mut_ and notifies us.
async_write(*sock_,feedbuf_.data(),
lslboost::bind(&client_session::handle_chunk_transfer_outcome,shared_from_this(),placeholders::error,placeholders::bytes_transferred));
// NOTE(review): transfer_completed() is assumed to return transfer_completed_ — confirm in header.
completion_cond_.wait(lock, lslboost::bind(&client_session::transfer_completed,this));
if (!transfer_error_) {
// Drop exactly the bytes that were reported as written.
feedbuf_.consume(transfer_amount_);
} else
break;
}
} catch(std::exception &e) {
std::cerr << "Unexpected glitch in transfer_samples_thread (id: " << lslboost::this_thread::get_id() << "): " << e.what() << std::endl;
}
}
} catch(std::exception &e) {
std::cerr << "Unexpected error in transfer_samples_thread (id: " << lslboost::this_thread::get_id() << "): " << e.what() << "; exiting..." << std::endl;
}
}
/**
 * Completion handler for a chunk write started by transfer_samples_thread.
 *
 * Records the error code and byte count under completion_mut_, marks the
 * transfer as completed, and wakes all threads waiting on completion_cond_.
 * Failures are logged to std::cerr.
 */
void tcp_server::client_session::handle_chunk_transfer_outcome(error_code err, std::size_t len) {
	try {
		lslboost::unique_lock<lslboost::mutex> guard(completion_mut_);
		transfer_error_ = err;
		transfer_amount_ = len;
		transfer_completed_ = true;
		// release before notifying so the woken thread can take the mutex right away
		guard.unlock();
		completion_cond_.notify_all();
	} catch (std::exception &e) {
		std::cerr << "Catastrophic error in handling the chunk transfer outcome (in tcp_server): " << e.what() << std::endl;
	}
}