#ifndef OS_WINDOWS
#include "common_utest.h"
#include <sys/wait.h>
#include <cblas.h>
/*
 * Allocate n bytes with malloc().
 *
 * Never returns NULL: if malloc() fails, a message is written to stderr
 * and the process terminates with exit status 1.
 */
void* xmalloc(size_t n)
{
    void *block = malloc(n);

    if (block == NULL) {
        fprintf(stderr, "You are about to die\n");
        exit(1);
    }
    return block;
}
/*
 * Compute result = a * b with cblas_dgemm, treating a, b and result as
 * n-by-n row-major matrices (alpha = 1.0, beta = 0.0), then assert that
 * every entry of result matches the corresponding entry of expected to
 * within CHECK_EPS.
 *
 * result is overwritten; a, b and expected are only read by this function.
 */
void check_dgemm(double *a, double *b, double *result, double *expected, int n)
{
    const int total = n * n;
    int idx;

    cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
                n, n, n,
                1.0, a, n,
                b, n,
                0.0, result, n);

    for (idx = 0; idx < total; ++idx) {
        CU_ASSERT_DOUBLE_EQUAL(expected[idx], result[idx], CHECK_EPS);
    }
}
void test_fork_safety(void)
{
int n = 1000;
int i;
double *a, *b, *c, *d;
size_t n_bytes;
pid_t fork_pid;
pid_t fork_pid_nested;
n_bytes = sizeof(*a) * n * n;
a = xmalloc(n_bytes);
b = xmalloc(n_bytes);
c = xmalloc(n_bytes);
d = xmalloc(n_bytes);
for(i = 0; i < n * n; ++i) {
a[i] = 1;
b[i] = 1;
}
cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, n, n, n,
1.0, a, n, b, n, 0.0, c, n);
fork_pid = fork();
if (fork_pid == -1) {
CU_FAIL("Failed to fork process.");
} else if (fork_pid == 0) {
check_dgemm(a, b, d, c, n);
fork_pid_nested = fork();
if (fork_pid_nested == -1) {
CU_FAIL("Failed to fork process.");
exit(1);
} else if (fork_pid_nested == 0) {
check_dgemm(a, b, d, c, n);
exit(0);
} else {
check_dgemm(a, b, d, c, n);
int child_status = 0;
pid_t wait_pid = wait(&child_status);
CU_ASSERT(wait_pid == fork_pid_nested);
CU_ASSERT(WEXITSTATUS (child_status) == 0);
exit(0);
}
} else {
check_dgemm(a, b, d, c, n);
int child_status = 0;
pid_t wait_pid = wait(&child_status);
CU_ASSERT(wait_pid == fork_pid);
CU_ASSERT(WEXITSTATUS (child_status) == 0);
}
}
#endif