#include <iostream>
#include <boost/bind.hpp>
#include <boost/smart_ptr/scoped_ptr.hpp>
#include <boost/algorithm/string/split.hpp>
#include <boost/algorithm/string/trim.hpp>
#include <boost/algorithm/string/case_conv.hpp>
#include <boost/algorithm/string/predicate.hpp>
#include <boost/type_traits/conditional.hpp>
#include "cancellable_streambuf.h"
#include "data_receiver.h"
#include "sample.h"
#include "socket_utils.h"
#define NO_EXPLICIT_TEMPLATE_INSTANTIATION
#include "portable_archive/portable_iarchive.hpp"
using namespace lslboost::algorithm;
namespace lsl {
data_receiver::data_receiver(inlet_connection &conn, int max_buflen, int max_chunklen): conn_(conn), check_thread_start_(true), closing_stream_(false), connected_(false), sample_queue_(max_buflen),
sample_factory_(new factory(conn.type_info().channel_format(),conn.type_info().channel_count(),conn.type_info().nominal_srate()?conn.type_info().nominal_srate()*api_config::get_instance()->inlet_buffer_reserve_ms()/1000:api_config::get_instance()->inlet_buffer_reserve_samples())), max_buflen_(max_buflen), max_chunklen_(max_chunklen)
{
if (max_buflen < 0)
throw std::invalid_argument("The max_buflen argument must not be smaller than 0.");
if (max_chunklen < 0)
throw std::invalid_argument("The max_chunklen argument must not be smaller than 0.");
conn_.register_onlost(this,&connected_upd_);
}
/// Unregister from the connection and wait for the data thread to finish.
/// Destructors must not throw, so any failure is reported on std::cerr and swallowed.
/// NOTE(review): join() does not itself signal the thread to stop; this relies on the
/// connection having been shut down / the stream closed beforehand — confirm with callers.
data_receiver::~data_receiver() {
    try {
        conn_.unregister_onlost(this);
        const bool thread_running = data_thread_.joinable();
        if (thread_running)
            data_thread_.join();
    } catch (const std::exception &ex) {
        std::cerr << "Unexpected error during destruction of a data_receiver: " << ex.what() << std::endl;
    } catch (...) {
        std::cerr << "Severe error during data receiver shutdown." << std::endl;
    }
}
/// Start the data thread if needed and block until the connection is established.
/// @param timeout Seconds to wait; a value >= FOREVER waits without limit.
/// @throws timeout_error if the connection does not complete within @p timeout.
/// @throws lost_error if the stream has been lost.
void data_receiver::open_stream(double timeout) {
    closing_stream_ = false;
    lslboost::unique_lock<lslboost::mutex> lock(connected_mut_);
    if (!connection_completed()) {
        // Spawn the receive thread lazily (again after close_stream()).
        const bool needs_thread = check_thread_start_ && !data_thread_.joinable();
        if (needs_thread) {
            data_thread_ = lslboost::thread(&data_receiver::data_thread, this);
            check_thread_start_ = false;
        }
        // Wait on connected_upd_ until connection_completed() reports true.
        if (timeout >= FOREVER) {
            connected_upd_.wait(lock, lslboost::bind(&data_receiver::connection_completed, this));
        } else if (!connected_upd_.wait_for(lock, lslboost::chrono::duration<double>(timeout),
                                            lslboost::bind(&data_receiver::connection_completed, this))) {
            throw timeout_error("The open_stream() operation timed out.");
        }
    }
    if (conn_.lost())
        throw lost_error("The stream read by this inlet has been lost. To recover, you need to re-resolve the source and re-create the inlet.");
}
/// Ask the data thread to wind down.
void data_receiver::close_stream() {
    // Re-arm the lazy thread-start check used by open_stream() and pull_sample_*().
    check_thread_start_ = true;
    // Polled by the loop conditions in data_thread().
    closing_stream_ = true;
    // Presumably aborts blocking I/O on streambufs registered with this object
    // (data_thread() registers its buffer via register_at(this)) — verify.
    cancel_all_registered();
}
/// Pull one sample from the receive queue, converting its values to type T.
/// @param buffer          Destination for buffer_elements values.
/// @param buffer_elements Must equal the stream's channel count.
/// @param timeout         Seconds to wait for a sample.
/// @return The sample's timestamp, or 0.0 if no sample arrived within @p timeout.
/// @throws lost_error       if the stream has been lost.
/// @throws std::range_error if buffer_elements does not match the channel count.
template <class T>
double data_receiver::pull_sample_typed(T *buffer, int buffer_elements, double timeout) {
    if (conn_.lost())
        throw lost_error("The stream read by this inlet has been lost. To recover, you need to "
                         "re-resolve the source and re-create the inlet.");
    // Validate the caller's buffer before dequeuing, so that a size mismatch does not
    // consume (and thereby discard) a received sample.
    if (buffer_elements != conn_.type_info().channel_count())
        throw std::range_error("The number of buffer elements provided does not match the "
                               "number of channels in the sample.");
    // Start the data thread implicitly if necessary.
    if (check_thread_start_ && !data_thread_.joinable()) {
        data_thread_ = lslboost::thread(&data_receiver::data_thread, this);
        check_thread_start_ = false;
    }
    if (sample_p s = sample_queue_.pop_sample(timeout)) {
        s->retrieve_typed(buffer);
        return s->timestamp;
    }
    // No sample: either a timeout, or the data thread pushed an empty sample after
    // losing the stream.
    if (conn_.lost())
        throw lost_error("The stream read by this inlet has been lost. To recover, you need to "
                         "re-resolve the source and re-create the inlet.");
    return 0.0;
}
/// Fixed-width integer type with the same size as `long` on this platform.
typedef lslboost::conditional<sizeof(long) == 8, int64_t, int32_t>::type long_type;

/// Specialization of pull_sample_typed for `long` buffers.
/// The body is written out instead of forwarding to pull_sample_typed<long_type>: on LP64
/// platforms where int64_t is itself `long` (e.g. glibc x86_64), long_type == long and
/// such a forwarding call resolves back to this specialization, recursing without bound.
/// @throws lost_error       if the stream has been lost.
/// @throws std::range_error if buffer_elements does not match the channel count.
template <>
double data_receiver::pull_sample_typed(long *buffer, int buffer_elements, double timeout) {
    if (conn_.lost())
        throw lost_error("The stream read by this inlet has been lost. To recover, you need to "
                         "re-resolve the source and re-create the inlet.");
    // Validate before dequeuing so that a size mismatch does not discard a sample.
    if (buffer_elements != conn_.type_info().channel_count())
        throw std::range_error("The number of buffer elements provided does not match the "
                               "number of channels in the sample.");
    // Start the data thread implicitly if necessary.
    if (check_thread_start_ && !data_thread_.joinable()) {
        data_thread_ = lslboost::thread(&data_receiver::data_thread, this);
        check_thread_start_ = false;
    }
    if (sample_p s = sample_queue_.pop_sample(timeout)) {
        // long and long_type have the same size (see the typedef), so the buffer is filled
        // through the fixed-width type; when long_type is long this cast is a no-op.
        s->retrieve_typed(reinterpret_cast<long_type *>(buffer));
        return s->timestamp;
    }
    if (conn_.lost())
        throw lost_error("The stream read by this inlet has been lost. To recover, you need to "
                         "re-resolve the source and re-create the inlet.");
    return 0.0;
}
// Explicit instantiations of pull_sample_typed for the supported value types.
// The template argument is deduced from the type of the first parameter.
template double data_receiver::pull_sample_typed(char *, int, double);
template double data_receiver::pull_sample_typed(int16_t *, int, double);
template double data_receiver::pull_sample_typed(int32_t *, int, double);
template double data_receiver::pull_sample_typed(int64_t *, int, double);
template double data_receiver::pull_sample_typed(float *, int, double);
template double data_receiver::pull_sample_typed(double *, int, double);
template double data_receiver::pull_sample_typed(std::string *, int, double);
/// Pull one sample from the receive queue as raw bytes.
/// @param buffer       Destination for buffer_bytes bytes.
/// @param buffer_bytes Must equal the stream's sample size in bytes.
/// @param timeout      Seconds to wait for a sample.
/// @return The sample's timestamp, or 0.0 if no sample arrived within @p timeout.
/// @throws lost_error       if the stream has been lost.
/// @throws std::range_error if buffer_bytes does not match the sample size.
double data_receiver::pull_sample_untyped(void *buffer, int buffer_bytes, double timeout) {
    if (conn_.lost())
        throw lost_error("The stream read by this inlet has been lost. To recover, you need to re-resolve the source and re-create the inlet.");
    // Validate the caller's buffer before dequeuing, so that a size mismatch does not
    // consume (and thereby discard) a received sample.
    if (buffer_bytes != conn_.type_info().sample_bytes())
        throw std::range_error("The size of the provided buffer does not match the number of bytes in the sample.");
    // Start the data thread implicitly if necessary.
    if (check_thread_start_ && !data_thread_.joinable()) {
        data_thread_ = lslboost::thread(&data_receiver::data_thread, this);
        check_thread_start_ = false;
    }
    if (sample_p s = sample_queue_.pop_sample(timeout)) {
        s->retrieve_untyped(buffer);
        return s->timestamp;
    }
    // No sample: either a timeout, or the data thread pushed an empty sample after
    // losing the stream.
    if (conn_.lost())
        throw lost_error("The stream read by this inlet has been lost. To recover, you need to re-resolve the source and re-create the inlet.");
    return 0.0;
}
/// Body of the background receive thread.
///
/// While the connection is neither lost nor shut down and close_stream() has not been
/// requested, this (re)connects to the outlet's TCP endpoint and negotiates the stream-feed
/// protocol. Protocol >= 1.10 uses HTTP-like headers; otherwise the legacy request is sent.
/// It then checks two test-pattern samples, signals connected_upd_ and pushes every
/// received sample into sample_queue_.
///
/// Error handling:
/// - Recoverable errors (error_code, lost_error, other std::exception) call
///   conn_.try_recover_from_error() and retry after 500 ms.
/// - A shutdown_error is turned into a lost_error, which ends the loop.
/// - When a lost_error escapes the loop, an empty sample_p is pushed. pull_sample_* treats
///   it like a timeout and then re-checks conn_.lost().
void data_receiver::data_thread() {
    conn_.acquire_watchdog();
    factory_p factory(sample_factory_);
    try {
        while (!conn_.lost() && !conn_.shutdown() && !closing_stream_) {
            try {
                // Connect a cancellable stream to the server; both the connection and this
                // receiver can cancel it.
                cancellable_streambuf buffer;
                buffer.register_at(&conn_);
                buffer.register_at(this);
                std::iostream server_stream(&buffer);
                lslboost::scoped_ptr<eos::portable_iarchive> inarch;
                buffer.connect(conn_.get_tcp_endpoint());
                if (buffer.puberror())
                    throw buffer.puberror();

                // Negotiated parameters; these defaults apply to the legacy (1.00) protocol.
                int use_byte_order = 0;
                int data_protocol_version = 100;
                bool suppress_subnormals = false;
                int proposed_protocol_version = std::min(api_config::get_instance()->use_protocol_version(), conn_.type_info().version());
                if (proposed_protocol_version >= 110) {
                    // Request line and headers.
                    server_stream << "LSL:streamfeed/" << proposed_protocol_version << " " << conn_.current_uid() << "\r\n";
                    server_stream << "Native-Byte-Order: " << BOOST_BYTE_ORDER << "\r\n";
                    server_stream << "Endian-Performance: " << std::floor(measure_endian_performance()) << "\r\n";
                    server_stream << "Has-IEEE754-Floats: " << (format_ieee754[cft_float32] && format_ieee754[cft_double64]) << "\r\n";
                    server_stream << "Supports-Subnormals: " << format_subnormal[conn_.type_info().channel_format()] << "\r\n";
                    server_stream << "Value-Size: " << conn_.type_info().channel_bytes() << "\r\n";
                    server_stream << "Data-Protocol-Version: " << proposed_protocol_version << "\r\n";
                    server_stream << "Max-Buffer-Length: " << max_buflen_ << "\r\n";
                    server_stream << "Max-Chunk-Length: " << max_chunklen_ << "\r\n";
                    server_stream << "Hostname: " << conn_.type_info().hostname() << "\r\n";
                    server_stream << "Source-Id: " << conn_.type_info().source_id() << "\r\n";
                    server_stream << "Session-Id: " << conn_.type_info().session_id() << "\r\n";
                    server_stream << "\r\n" << std::flush;

                    // Status line: "LSL/<version> <code> <message>".
                    char buf[16384] = {0};
                    if (!server_stream.getline(buf, sizeof(buf)))
                        throw lost_error("Connection lost.");
                    std::vector<std::string> parts;
                    split(parts, buf, is_any_of(" \t"));
                    if (parts.size() < 3 || !starts_with(parts[0], "LSL/"))
                        throw std::runtime_error("Received a malformed response.");
                    if (from_string<int>(parts[0].substr(4)) / 100 > api_config::get_instance()->use_protocol_version() / 100)
                        throw std::runtime_error("The other party's protocol version is too new for this client; please upgrade your LSL library.");
                    int status_code = from_string<int>(parts[1]);
                    if (status_code == 404)
                        throw lost_error("The given address does not serve the resolved stream (likely outdated).");
                    if (status_code >= 400)
                        throw std::runtime_error("The other party sent an error: " + std::string(buf));
                    if (status_code >= 300)
                        throw lost_error("The other party requested a redirect.");

                    // Response headers, terminated by an empty ("\r") line. Names and values
                    // are trimmed and lower-cased; anything after ';' in a value is dropped.
                    while (server_stream.getline(buf, sizeof(buf)) && (buf[0] != '\r')) {
                        std::string hdrline(buf);
                        std::size_t colon = hdrline.find_first_of(':');
                        if (colon != std::string::npos) {
                            std::string type = to_lower_copy(trim_copy(hdrline.substr(0, colon)));
                            std::string rest = to_lower_copy(trim_copy(hdrline.substr(colon + 1)));
                            std::size_t semicolon = rest.find_first_of(';');
                            if (semicolon != std::string::npos)
                                rest = rest.substr(0, semicolon);
                            if (type == "byte-order") {
                                use_byte_order = from_string<int>(rest);
                                if (use_byte_order == 2134 && BOOST_BYTE_ORDER != 2134 && format_sizes[conn_.type_info().channel_format()] >= 8)
                                    throw std::runtime_error("The byte order conversion requested by the other party is not supported.");
                            }
                            if (type == "suppress-subnormals")
                                suppress_subnormals = from_string<bool>(rest);
                            // NOTE(review): rest is lower-cased but current_uid() is not;
                            // this assumes UIDs are always lower-case — confirm.
                            if (type == "uid" && rest != conn_.current_uid())
                                throw lost_error("The received UID does not match the current connection's UID.");
                            if (type == "data-protocol-version") {
                                data_protocol_version = from_string<int>(rest);
                                if (data_protocol_version > api_config::get_instance()->use_protocol_version())
                                    throw std::runtime_error("The protocol version requested by the other party is not supported by this client.");
                            }
                        }
                    }
                    if (!server_stream)
                        throw lost_error("Server connection lost.");
                } else {
                    // Legacy request.
                    server_stream << "LSL:streamfeed\r\n";
                    server_stream << max_buflen_ << " " << max_chunklen_ << "\r\n" << std::flush;
                }

                // Legacy protocol: the server first sends a shortinfo message via the archive.
                if (data_protocol_version == 100) {
                    inarch.reset(new eos::portable_iarchive(server_stream));
                    std::string infomsg;
                    *inarch >> infomsg;
                    stream_info_impl info;
                    info.from_shortinfo_message(infomsg);
                    if (info.uid() != conn_.current_uid())
                        throw lost_error("The received UID does not match the current connection's UID.");
                }

                // Receive two test-pattern samples and compare them with locally generated
                // ones to confirm that both sides agree on the wire format.
                {
                    lslboost::scoped_ptr<sample> temp[4];
                    for (int k = 0; k < 4; k++)
                        temp[k].reset(factory::new_sample_unmanaged(conn_.type_info().channel_format(), conn_.type_info().channel_count(), 0.0, false));
                    temp[0]->assign_test_pattern(4);
                    if (data_protocol_version >= 110)
                        temp[1]->load_streambuf(buffer, data_protocol_version, use_byte_order, suppress_subnormals);
                    else
                        *inarch >> *temp[1];
                    temp[2]->assign_test_pattern(2);
                    if (data_protocol_version >= 110)
                        temp[3]->load_streambuf(buffer, data_protocol_version, use_byte_order, suppress_subnormals);
                    else
                        *inarch >> *temp[3];
                    if (!(*temp[0].get() == *temp[1].get()) || !(*temp[2].get() == *temp[3].get()))
                        throw std::runtime_error("The received test-pattern samples do not match the specification. The protocol formats are likely incompatible.");
                }

                // Publish the established connection to open_stream() waiters.
                {
                    lslboost::lock_guard<lslboost::mutex> lock(connected_mut_);
                    connected_ = true;
                }
                connected_upd_.notify_all();

                // Main receive loop.
                double last_timestamp = 0.0;
                double srate = conn_.current_srate();
                // Unsigned so that a long-running stream wraps around instead of overflowing
                // (signed overflow is undefined behavior); only the low 4 bits are used below.
                for (unsigned int k = 0; !conn_.lost() && !conn_.shutdown() && !closing_stream_; k++) {
                    sample_p samp(factory->new_sample(0.0, false));
                    if (data_protocol_version >= 110)
                        samp->load_streambuf(buffer, data_protocol_version, use_byte_order, suppress_subnormals);
                    else
                        *inarch >> *samp;
                    // Deduce omitted timestamps from the previous one plus one sampling interval.
                    if (samp->timestamp == DEDUCED_TIMESTAMP) {
                        samp->timestamp = last_timestamp;
                        if (srate != IRREGULAR_RATE)
                            samp->timestamp += 1.0 / srate;
                    }
                    last_timestamp = samp->timestamp;
                    sample_queue_.push_sample(samp);
                    // Refresh the receive time on every sample for slow streams (<= 16 Hz),
                    // otherwise on every 16th sample.
                    if (srate <= 16 || (k & 0xF) == 0)
                        conn_.update_receive_time(lsl_clock());
                }
            } catch (error_code &) {
                conn_.try_recover_from_error();
            } catch (lost_error &) {
                conn_.try_recover_from_error();
            } catch (shutdown_error &) {
                throw lost_error("The inlet has been disengaged.");
            } catch (std::exception &e) {
                if (!conn_.shutdown())
                    std::cerr << "Stream transmission broke off (" << e.what() << "); re-connecting..." << std::endl;
                conn_.try_recover_from_error();
            }
            // Back off before the next connection attempt.
            lslboost::this_thread::sleep_for(lslboost::chrono::milliseconds(500));
        }
    } catch (lost_error &) {
        sample_queue_.push_sample(sample_p());
    }
    conn_.release_watchdog();
}
}