#if __INTEL_COMPILER && _MSC_VER
#pragma warning(disable : 2586)
#endif
#include "common/config.h"
#include "tbb/flow_graph.h"
#include "common/test.h"
#include "common/utils.h"
#include "common/utils_assert.h"
#include "common/test_follows_and_precedes_api.h"
#include <atomic>
// Upper bound (exclusive) for the limiter thresholds iterated by test_serial and
// test_parallel, and the number of puts each put_body invocation attempts.
const int L = 10;
// Number of accepted puts each put_dec_body invocation waits for, and the number of
// puts issued per threshold by the second loop of test_serial.
const int N = 1000;
// Internal TBB names needed by the hand-written receivers in this file, which
// override try_put_task directly.
using tbb::detail::d1::SUCCESSFULLY_ENQUEUED;
using tbb::detail::d1::graph_task;
// Receiver used by the serial tests: it checks that the values it is given arrive
// as T(0), T(1), T(2), ... in exactly that order.
template< typename T >
struct serial_receiver : public tbb::flow::receiver<T>, utils::NoAssign {
    T next_value;                 // value the next incoming message must equal
    tbb::flow::graph& my_graph;   // graph reported through graph_reference()

    serial_receiver(tbb::flow::graph& g) : next_value(T(0)), my_graph(g) {}

    // Verifies the incoming value against the running counter, advances the
    // counter, and always reports the message as accepted.
    graph_task* try_put_task( const T &v ) override {
        const T expected = next_value++;
        CHECK_MESSAGE( expected == v, "" );
        return const_cast<graph_task*>(SUCCESSFULLY_ENQUEUED);
    }

    tbb::flow::graph& graph_reference() const override {
        return my_graph;
    }
};
// Receiver used by the parallel tests: it ignores message contents and only
// counts, atomically, how many messages it has been given.
template< typename T >
struct parallel_receiver : public tbb::flow::receiver<T>, utils::NoAssign {
    std::atomic<int> my_count;    // number of messages received so far
    tbb::flow::graph& my_graph;   // graph reported through graph_reference()

    parallel_receiver(tbb::flow::graph& g) : my_count(0), my_graph(g) {}

    // Counts the message and always reports it as accepted.
    graph_task* try_put_task( const T & ) override {
        my_count.fetch_add(1);
        return const_cast<graph_task*>(SUCCESSFULLY_ENQUEUED);
    }

    tbb::flow::graph& graph_reference() const override {
        return my_graph;
    }
};
// Sender stub that never produces anything: both successor-management calls
// simply report failure.
template< typename T >
struct empty_sender : public tbb::flow::sender<T> {
    using successor_type = typename tbb::flow::sender<T>::successor_type;

    bool register_successor( successor_type & ) override {
        return false;
    }

    bool remove_successor( successor_type & ) override {
        return false;
    }
};
// Thread body for the "puts without decrements" test: offers L messages to the
// limiter and adds every accepted one to the shared counter.
template< typename T >
struct put_body : utils::NoAssign {
    tbb::flow::limiter_node<T> &my_lim;      // limiter under test
    std::atomic<int> &my_accept_count;       // shared tally of accepted puts

    put_body( tbb::flow::limiter_node<T> &lim, std::atomic<int> &accept_count ) :
        my_lim(lim), my_accept_count(accept_count) {}

    // The int argument (thread index) is unused.
    void operator()( int ) const {
        int attempt = 0;
        while ( attempt < L ) {
            if ( my_lim.try_put( T(attempt) ) ) {
                ++my_accept_count;
            }
            ++attempt;
        }
    }
};
// Thread body for the "puts with decrements" test: keeps offering messages until
// N of them have been accepted, sending one continue_msg to the limiter's
// decrementer after every accepted put.
template< typename T >
struct put_dec_body : utils::NoAssign {
    tbb::flow::limiter_node<T> &my_lim;      // limiter under test
    std::atomic<int> &my_accept_count;       // shared tally of accepted puts

    put_dec_body( tbb::flow::limiter_node<T> &lim, std::atomic<int> &accept_count ) :
        my_lim(lim), my_accept_count(accept_count) {}

    // The int argument (thread index) is unused.
    void operator()( int ) const {
        for ( int accepted = 0; accepted < N; ) {
            if ( !my_lim.try_put( T(accepted) ) ) {
                // Rejected: retry the same value.
                continue;
            }
            ++accepted;
            ++my_accept_count;
            my_lim.decrementer().try_put( tbb::flow::continue_msg() );
        }
    }
};
// Connects `sender` to `receiver` with tbb::flow::make_edge.
//
// Two call forms are selected by the preprocessor:
//  - when __GNUC__ < 12 and TBB_USE_DEBUG is not set, the edge is made through a
//    function pointer to the make_edge<int> instantiation (so this path only
//    works for int messages);
//  - otherwise make_edge is called directly.
// Note that an undefined __GNUC__ evaluates to 0 in #if, so the first form is also
// taken by compilers that do not define it.
// NOTE(review): the reason for the indirect call is not visible here — presumably
// a compiler workaround; confirm before simplifying.
template< typename Sender, typename Receiver >
void make_edge_impl(Sender& sender, Receiver& receiver){
#if __GNUC__ < 12 && !TBB_USE_DEBUG
auto make_edge_ptr = tbb::flow::make_edge<int>;
make_edge_ptr(sender, receiver);
#else
tbb::flow::make_edge(sender, receiver);
#endif
}
template< typename T >
void test_puts_with_decrements( int num_threads, tbb::flow::limiter_node< T >& lim , tbb::flow::graph& g) {
parallel_receiver<T> r(g);
empty_sender< tbb::flow::continue_msg > s;
std::atomic<int> accept_count;
accept_count = 0;
tbb::flow::make_edge( lim, r );
tbb::flow::make_edge(s, lim.decrementer());
utils::NativeParallelFor( num_threads, put_dec_body<T>(lim, accept_count) );
int c = accept_count;
CHECK_MESSAGE( c == N*num_threads, "" );
CHECK_MESSAGE( r.my_count == N*num_threads, "" );
}
template< typename T >
int test_parallel(int num_threads) {
for ( int i = 0; i < L; ++i ) {
tbb::flow::graph g;
tbb::flow::limiter_node< T > lim(g, i);
parallel_receiver<T> r(g);
std::atomic<int> accept_count;
accept_count = 0;
tbb::flow::make_edge( lim, r );
utils::NativeParallelFor( num_threads, put_body<T>(lim, accept_count) );
g.wait_for_all();
int c = accept_count;
CHECK_MESSAGE( c == i, "" );
}
for ( int i = 1; i < L; ++i ) {
tbb::flow::graph g;
tbb::flow::limiter_node< T > lim(g, i);
test_puts_with_decrements(num_threads, lim, g);
tbb::flow::limiter_node< T > lim_copy( lim );
test_puts_with_decrements(num_threads, lim_copy, g);
}
return 0;
}
template< typename T >
int test_serial() {
for ( int i = 0; i < L; ++i ) {
tbb::flow::graph g;
tbb::flow::limiter_node< T > lim(g, i);
serial_receiver<T> r(g);
tbb::flow::make_edge( lim, r );
for ( int j = 0; j < L; ++j ) {
bool msg = lim.try_put( T(j) );
CHECK_MESSAGE( (( j < i && msg == true ) || ( j >= i && msg == false )), "" );
}
g.wait_for_all();
}
for ( int i = 1; i < L; ++i ) {
tbb::flow::graph g;
tbb::flow::limiter_node< T > lim(g, i);
serial_receiver<T> r(g);
empty_sender< tbb::flow::continue_msg > s;
tbb::flow::make_edge( lim, r );
tbb::flow::make_edge(s, lim.decrementer());
for ( int j = 0; j < N; ++j ) {
bool msg = lim.try_put( T(j) );
CHECK_MESSAGE( (( j < i && msg == true ) || ( j >= i && msg == false )), "" );
if ( msg == false ) {
lim.decrementer().try_put( tbb::flow::continue_msg() );
msg = lim.try_put( T(j) );
CHECK_MESSAGE( msg == true, "" );
}
}
}
return 0;
}
// Output-port indices of mfnode_type: port DECREMENT_OUTPUT (continue_msg) is wired
// to the limiter's decrementer and port LIMITER_OUTPUT (int) to the limiter's
// input in test_multifunction_to_limiter.
#define DECREMENT_OUTPUT 1
#define LIMITER_OUTPUT 0
// Multifunction node with an int output (port 0) and a continue_msg output (port 1).
typedef tbb::flow::multifunction_node<int, std::tuple<int,tbb::flow::continue_msg> > mfnode_type;
// Number and sum of the values mfnode_body successfully put to the limiter.
std::atomic<size_t> emit_count;
std::atomic<size_t> emit_sum;
// Number and sum of the values fn_body received.
std::atomic<size_t> receive_count;
std::atomic<size_t> receive_sum;
// Body of the multifunction node feeding the limiter. Each invocation bumps a
// shared counter; for the first max_cnt invocations it sends one decrement and
// then pushes its invocation number into the limiter until the limiter rejects.
struct mfnode_body {
    int max_cnt;                 // invocations beyond this number do nothing
    std::atomic<int>* my_cnt;    // shared invocation counter owned by the caller

    mfnode_body(const int& _max, std::atomic<int> &_my) : max_cnt(_max), my_cnt(&_my) { }

    // The input value is ignored; only the invocation number matters.
    void operator()(const int &, mfnode_type::output_ports_type &out) {
        const int invocation = my_cnt->fetch_add(1) + 1;
        if ( invocation <= max_cnt ) {
            auto& decrement_port = std::get<DECREMENT_OUTPUT>(out);
            if ( !decrement_port.try_put(tbb::flow::continue_msg()) ) {
                CHECK_MESSAGE( (false),"Unexpected rejection of decrement");
            }

            // Record every value the limiter accepts so the totals can be
            // compared with what fn_body receives.
            auto& limiter_port = std::get<LIMITER_OUTPUT>(out);
            while ( limiter_port.try_put(invocation) ) {
                emit_sum += invocation;
                ++emit_count;
            }
        }
    }
};
// Body of the function node downstream of the limiter: adds each received value
// to receive_sum, bumps receive_count, and passes the value through unchanged.
struct fn_body {
    int operator()(const int &in) {
        receive_sum.fetch_add(in);
        receive_count.fetch_add(1);
        return in;
    }
};
void
test_multifunction_to_limiter(int _max, int _nparallel) {
tbb::flow::graph g;
emit_count = 0;
emit_sum = 0;
receive_count = 0;
receive_sum = 0;
std::atomic<int> local_cnt;
local_cnt = 0;
mfnode_type mf_node(g, tbb::flow::unlimited, mfnode_body(_max, local_cnt));
tbb::flow::function_node<int, int> fn_node(g, tbb::flow::unlimited, fn_body());
tbb::flow::limiter_node<int> lim_node(g, _nparallel);
tbb::flow::make_edge(tbb::flow::output_port<LIMITER_OUTPUT>(mf_node), lim_node);
tbb::flow::make_edge(tbb::flow::output_port<DECREMENT_OUTPUT>(mf_node), lim_node.decrementer());
tbb::flow::make_edge(lim_node, fn_node);
tbb::flow::make_edge(fn_node, mf_node);
mf_node.try_put(1);
g.wait_for_all();
CHECK_MESSAGE( (emit_count == receive_count), "counts do not match");
CHECK_MESSAGE( (emit_sum == receive_sum), "sums do not match");
g.reset();
emit_count = 0;
emit_sum = 0;
receive_count = 0;
receive_sum = 0;
local_cnt = 0;;
mf_node.try_put(1);
g.wait_for_all();
CHECK_MESSAGE( (emit_count == receive_count), "counts do not match");
CHECK_MESSAGE( (emit_sum == receive_sum), "sums do not match");
}
// Sends a continue_msg to the decrementer of a fresh limiter (threshold 2) before
// any data, then checks that a subsequent put of 42 still reaches the queue.
void
test_continue_msg_reception() {
    tbb::flow::graph g;
    tbb::flow::limiter_node<int> ln(g,2);
    tbb::flow::queue_node<int> qn(g);
    tbb::flow::make_edge(ln, qn);

    ln.decrementer().try_put(tbb::flow::continue_msg());
    ln.try_put(42);
    g.wait_for_all();

    int received = 0;
    const bool got_item = qn.try_get(received);
    CHECK_MESSAGE( (got_item && received == 42), "initial put to decrement stops node");
}
// Checks that no message is lost when the limiter temporarily has no successor.
//
// Topology: input_queue -> limit (threshold 2) -> output_queue, with a
// broadcast_node feeding the limiter's decrementer. The final loop expects all
// four values of `list` to come out of output_queue in input order.
void test_reserve_release_messages() {
    using namespace tbb::flow;
    graph g;

    queue_node<int> input_queue(g);
    queue_node<int> output_queue(g);
    broadcast_node<int> broad(g);
    limiter_node<int, int> limit(g,2);

    make_edge(input_queue, limit);
    make_edge(limit, output_queue);
    make_edge(broad,limit.decrementer());

    int list[4] = {19, 33, 72, 98};

    // Three values against a threshold of 2.
    // NOTE(review): presumably the third stays buffered in input_queue — confirm.
    input_queue.try_put(list[0]);
    input_queue.try_put(list[1]);
    input_queue.try_put(list[2]);
    g.wait_for_all();

    // Detach the successor, then decrement while nothing can accept the output.
    remove_edge(limit, output_queue);
    broad.try_put(1);
    g.wait_for_all();

    // Re-attach the successor, decrement once more and feed the last value.
    make_edge_impl(limit, output_queue);
    broad.try_put(1);
    input_queue.try_put(list[3]);
    g.wait_for_all();

    int var=0;
    for (const int expected : list) {
        // Check the result of try_get explicitly: a failed get leaves `var`
        // untouched, so it must not be compared as if it were fresh data.
        const bool got = output_queue.try_get(var);
        CHECK_MESSAGE( (got && var==expected), "some data dropped, input does not match output");
        g.wait_for_all();
    }
}
// Exercises the typed decrementer port of limiter_node in three scenarios, all
// draining into the same queue:
//  1. int decrementer: a negative value closes the gate, a positive one reopens it;
//  2. long long decrementer with the maximum size_t threshold: two large negative
//     values must close the gate without the internal counter overflowing;
//  3. a negative value sent before any put reduces how many puts are accepted.
void test_decrementer() {
    // --- Scenario 1 -------------------------------------------------------
    const int threshold = 5;
    tbb::flow::graph g;
    tbb::flow::limiter_node<int, int> limit(g, threshold);
    tbb::flow::queue_node<int> queue(g);
    make_edge(limit, queue);
    int m = 0;
    CHECK_MESSAGE( ( limit.try_put( m++ )), "Newly constructed limiter node does not accept message." );
    // Putting -threshold is expected to close the gate...
    CHECK_MESSAGE( limit.decrementer().try_put( -threshold ), "Limiter node decrementer's port does not accept message." );
    CHECK_MESSAGE( ( !limit.try_put( m++ )), "Closed limiter node's accepts message." );
    // ...and threshold + 5 is expected to open it again for `threshold` more puts.
    CHECK_MESSAGE( limit.decrementer().try_put( threshold + 5 ), "Limiter node decrementer's port does not accept message." );
    for( int i = 0; i < threshold; ++i )
        CHECK_MESSAGE( ( limit.try_put( m++ )), "Limiter node does not accept message while open." );
    CHECK_MESSAGE( ( !limit.try_put( m )), "Limiter node's gate is not closed." );
    g.wait_for_all();

    // Value 1 was rejected while the gate was closed, so it must be missing.
    const int expected[] = {0, 2, 3, 4, 5, 6};
    const int expected_count = static_cast<int>( sizeof(expected) / sizeof(expected[0]) );
    int actual = -1; m = 0;
    while( queue.try_get(actual) ) {
        // Guard the index: surplus messages must fail the count check below
        // instead of reading past the end of `expected`.
        if( m < expected_count ) {
            CHECK_MESSAGE( actual == expected[m], "" );
        }
        ++m;
    }
    CHECK_MESSAGE( ( expected_count == m), "Not all messages have been processed." );
    g.wait_for_all();

    // --- Scenario 2 -------------------------------------------------------
    const size_t threshold2 = size_t(-1);
    tbb::flow::limiter_node<int, long long> limit2(g, threshold2);
    make_edge(limit2, queue);
    CHECK_MESSAGE( ( limit2.try_put( 1 )), "Newly constructed limiter node does not accept message." );
    long long decrement_value = (long long)( size_t(-1)/2 );
    CHECK_MESSAGE( limit2.decrementer().try_put( -decrement_value ),
                   "Limiter node decrementer's port does not accept message" );
    CHECK_MESSAGE( ( limit2.try_put( 2 )), "Limiter's gate should not be closed yet." );
    CHECK_MESSAGE( limit2.decrementer().try_put( -decrement_value ),
                   "Limiter node decrementer's port does not accept message" );
    CHECK_MESSAGE( ( !limit2.try_put( 3 )), "Overflow happened for internal counter." );

    const int expected2[] = {1, 2};
    const int expected2_count = static_cast<int>( sizeof(expected2) / sizeof(expected2[0]) );
    actual = -1; m = 0;
    while( queue.try_get(actual) ) {
        // Same index guard as in scenario 1.
        if( m < expected2_count ) {
            CHECK_MESSAGE( actual == expected2[m], "" );
        }
        ++m;
    }
    CHECK_MESSAGE( ( expected2_count == m), "Not all messages have been processed." );
    g.wait_for_all();

    // --- Scenario 3 -------------------------------------------------------
    const size_t threshold3 = 10;
    tbb::flow::limiter_node<int, long long> limit3(g, threshold3);
    make_edge(limit3, queue);
    long long decrement_value3 = 3;
    CHECK_MESSAGE( limit3.decrementer().try_put( -decrement_value3 ),
                   "Limiter node decrementer's port does not accept message" );

    // Put until the first rejection; m ends up as the number of accepted puts.
    m = 0;
    while( limit3.try_put( m ) ){ m++; };
    CHECK_MESSAGE( m == threshold3 - decrement_value3, "Not all messages have been accepted." );

    // The accepted values are 0, 1, 2, ... and must come out in that order.
    actual = -1; m = 0;
    while( queue.try_get(actual) ){
        CHECK_MESSAGE( actual == m++, "Not all messages have been processed." );
    }
    g.wait_for_all();
    CHECK_MESSAGE( m == threshold3 - decrement_value3, "Not all messages have been processed." );
}
// Fills a buffer_node that feeds a limiter with no successor yet, then attaches a
// summing function_node and checks the sum of everything that got through. After
// that the buffer is detached and one more value is put into it; the sum is
// expected to stay unchanged.
void test_try_put_without_successors() {
    tbb::flow::graph g;
    int try_put_num{3};
    tbb::flow::buffer_node<int> bn(g);
    tbb::flow::limiter_node<int> ln(g, try_put_num);
    tbb::flow::make_edge(bn, ln);

    // Put 1 .. try_put_num; `i` ends at try_put_num + 1.
    int i = 1;
    while (i <= try_put_num) {
        bn.try_put(i);
        ++i;
    }

    std::atomic<int> counter{0};
    tbb::flow::function_node<int, int> fn(g, tbb::flow::unlimited,
        [&counter](int input) {
            counter.fetch_add(input);
            return int{};
        }
    );
    make_edge_impl(ln, fn);
    g.wait_for_all();

    // (try_put_num + 1) * try_put_num / 2 == 1 + 2 + ... + try_put_num.
    const int expected_sum = i * try_put_num / 2;
    CHECK((counter == expected_sum));

    // With the buffer detached, a new value put into it must not reach fn.
    tbb::flow::remove_edge(bn, ln);
    ln.decrementer().try_put(tbb::flow::continue_msg());
    bn.try_put(try_put_num + 1);
    g.wait_for_all();
    CHECK((counter == expected_sum));
}
#if __TBB_PREVIEW_FLOW_GRAPH_NODE_SET
#include <array>
#include <vector>
void test_follows_and_precedes_api() {
using msg_t = tbb::flow::continue_msg;
std::array<msg_t, 3> messages_for_follows= { {msg_t(), msg_t(), msg_t()} };
std::vector<msg_t> messages_for_precedes = {msg_t()};
follows_and_precedes_testing::test_follows
<msg_t, tbb::flow::limiter_node<msg_t, msg_t>>(messages_for_follows, 1000);
follows_and_precedes_testing::test_precedes
<msg_t, tbb::flow::limiter_node<msg_t, msg_t>>(messages_for_precedes, 1000);
}
#endif
#if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT
void test_deduction_guides() {
using namespace tbb::flow;
graph g;
broadcast_node<int> br(g);
limiter_node<int> l0(g, 100);
#if __TBB_PREVIEW_FLOW_GRAPH_NODE_SET
limiter_node l1(follows(br), 100);
static_assert(std::is_same_v<decltype(l1), limiter_node<int>>);
limiter_node l2(precedes(br), 100);
static_assert(std::is_same_v<decltype(l2), limiter_node<int>>);
#endif
limiter_node l3(l0);
static_assert(std::is_same_v<decltype(l3), limiter_node<int>>);
}
#endif
// Runs the serial and parallel limiter tests inside task arenas of every
// concurrency level from 1 up to 2 * utils::MaxThread; the arena's concurrency is
// also the number of native threads test_parallel uses.
TEST_CASE("Serial and parallel tests") {
    for (unsigned concurrency = 1; concurrency <= 2 * utils::MaxThread; ++concurrency) {
        tbb::task_arena arena(concurrency);
        auto run_tests = [concurrency]() {
            test_serial<int>();
            test_parallel<int>(concurrency);
        };
        arena.execute(run_tests);
    }
}
// A continue_msg sent to the decrementer before any data must not stop the node.
TEST_CASE("Test continue_msg reception") {
test_continue_msg_reception();
}
// Runs the multifunction_node/limiter_node cycle with three
// (emit limit, limiter threshold) combinations.
TEST_CASE("Multifunction connected to limiter") {
test_multifunction_to_limiter(30,3);
test_multifunction_to_limiter(300,13);
test_multifunction_to_limiter(3000,1);
}
// No message may be dropped while the limiter temporarily has no successor.
TEST_CASE("Message is released if successor does not accept") {
test_reserve_release_messages();
}
// Typed decrementer port: closing/opening the gate and counter overflow handling.
TEST_CASE("Decrementer") {
test_decrementer();
}
// Messages offered to a limiter before it has any successor.
TEST_CASE("Test try_put() without successors") {
test_try_put_without_successors();
}
#if __TBB_PREVIEW_FLOW_GRAPH_NODE_SET
// Only built when the flow-graph node-set preview feature is enabled.
TEST_CASE( "Support for follows and precedes API" ) {
test_follows_and_precedes_api();
}
#endif
#if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT
// Only built when C++17 deduction guides are available; the checks themselves
// are static_asserts inside test_deduction_guides.
TEST_CASE( "Deduction guides" ) {
test_deduction_guides();
}
#endif