#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include "../rssn-advanced.h"
/*
 * CHECK(call, msg): evaluate an rssn API call exactly once and abort the
 * program if it does not return RssnStatusSuccess.
 *
 * On failure it prints "FAIL [status=N]: msg" to stderr and calls exit(1).
 * No cleanup of live rssn handles is attempted, so use it only where process
 * exit is an acceptable way to release them.
 *
 * The status is cast to int for "%d" because the integer type an enum is
 * compatible with is implementation-defined. The local has a deliberately
 * unusual name so it is unlikely to shadow a caller variable referenced inside
 * `call`. With the old name `_s`, `CHECK(f(_s), ...)` would have read the new,
 * uninitialized local.
 */
#define CHECK(call, msg) \
    do { \
        enum RssnStatus rssn_check_status_ = (call); \
        if (rssn_check_status_ != RssnStatusSuccess) { \
            fprintf(stderr, "FAIL [status=%d]: %s\n", \
                    (int)rssn_check_status_, (msg)); \
            exit(1); \
        } \
    } while (0)
/*
 * C implementation of ReLU registered as the "relu" custom operator.
 * Returns x when x is strictly positive. Otherwise it returns +0.0; -0.0 and
 * NaN also take that path, since both fail the `x > 0.0` test.
 */
static double relu_impl(double x)
{
    if (!(x > 0.0)) {
        return 0.0;
    }
    return x;
}
/*
 * Compare a JIT-computed value with its C reference using a small absolute
 * tolerance. An exact == would report harmless rounding differences as
 * failures, for example if the compiled code evaluates x^2 through a general
 * pow routine. NaN never compares as close.
 */
static int approx_equal(double got, double want)
{
    double diff = got - want;
    if (diff < 0.0) {
        diff = -diff;
    }
    return diff <= 1e-9;
}

/*
 * Demo driver for the rssn DAG builder / JIT.
 *   1. Builds f(x) = x^2 + 2x + 1, compiles it, and checks it against (x+1)^2.
 *   2. Registers relu_impl as the custom operator "relu", builds relu(x + 3),
 *      compiles it with the custom-op registry, and checks it against a C
 *      reference.
 *
 * Any rssn call wrapped in CHECK that returns a non-success status terminates
 * the process with exit(1). Returns 0 on success, and 1 if a constructor
 * returns NULL or function interning fails. The handles created so far are
 * released in both cases.
 */
int main(void)
{
    int rc = 1; /* exit status; set to 0 only after both demos complete */
    struct DagBuilder *dag = NULL;
    struct DagBuilder *dag2 = NULL;
    struct RssnCustomOpRegistry *reg = NULL;

    dag = rssn_dag_new();
    if (!dag) {
        fputs("rssn_dag_new failed\n", stderr);
        goto cleanup;
    }

    /* ---- Part 1: build and compile f(x) = x^2 + 2x + 1 ---- */
    uint32_t x_id, two_id, one_id;
    uint32_t x2_id, two_x_id, sum1_id, root_id;
    CHECK(rssn_dag_variable_v2(dag, "x", &x_id), "variable x");
    CHECK(rssn_dag_constant_v2(dag, 2.0, &two_id), "constant 2");
    CHECK(rssn_dag_constant_v2(dag, 1.0, &one_id), "constant 1");
    CHECK(rssn_dag_pow_v2(dag, x_id, two_id, &x2_id), "x^2");
    CHECK(rssn_dag_mul_v2(dag, two_id, x_id, &two_x_id), "2*x");
    CHECK(rssn_dag_add_v2(dag, x2_id, two_x_id, &sum1_id), "x^2+2x");
    CHECK(rssn_dag_add_v2(dag, sum1_id, one_id, &root_id), "x^2+2x+1");
    /* uint32_t is not guaranteed to be unsigned int, so print it via a cast
       to unsigned long, which is at least 32 bits wide. */
    printf("Expression built. Root node id = %lu\n", (unsigned long)root_id);

    void *fn_ptr = NULL;
    CHECK(rssn_dag_compile_v2(dag, root_id, &fn_ptr), "compile x^2+2x+1");

    printf("\nx^2 + 2x + 1 (should equal (x+1)^2):\n");
    const double xs[] = { -2.0, -1.0, 0.0, 1.0, 2.0, 3.0 };
    for (size_t i = 0; i < sizeof xs / sizeof xs[0]; i++) {
        double result = 0.0;
        /* Input values for the compiled function; this DAG has the single
           variable "x". */
        double vars[] = { xs[i] };
        CHECK(rssn_dag_execute_v2(fn_ptr, vars, &result), "execute");
        double expected = (xs[i] + 1.0) * (xs[i] + 1.0);
        printf(" f(%.1f) = %.1f expected %.1f%s\n",
               xs[i], result, expected,
               approx_equal(result, expected) ? " ✓" : " ✗");
    }

    /* ---- Part 2: custom operator relu(x + 3) ---- */
    printf("\nCustom operator relu(x + 3):\n");
    reg = rssn_custom_op_registry_new();
    if (!reg) {
        fputs("registry_new failed\n", stderr);
        goto cleanup;
    }
    dag2 = rssn_dag_new();
    if (!dag2) {
        fputs("rssn_dag_new failed\n", stderr);
        goto cleanup;
    }

    uint32_t relu_x_id, three_id, xp3_id, relu_fn_id, relu_node;
    CHECK(rssn_dag_variable_v2(dag2, "x", &relu_x_id), "variable x");
    CHECK(rssn_dag_constant_v2(dag2, 3.0, &three_id), "constant 3");
    CHECK(rssn_dag_add_v2(dag2, relu_x_id, three_id, &xp3_id), "x+3");

    /* rssn_dag_intern_function signals failure with (uint32_t)-1. */
    relu_fn_id = rssn_dag_intern_function(dag2, "relu");
    if (relu_fn_id == (uint32_t)-1) {
        fputs("rssn_dag_intern_function failed\n", stderr);
        goto cleanup;
    }
    /* Bind the interned id to the C implementation. NOTE(review): the meaning
       of the trailing 1 is not visible in this file; check rssn-advanced.h. */
    CHECK(rssn_custom_op_register_fn1(reg, relu_fn_id, "relu", relu_impl, 1),
          "register relu");
    CHECK(rssn_dag_call_fn_v2(dag2, relu_fn_id, &xp3_id, 1, &relu_node),
          "relu(x+3)");

    void *relu_fn = NULL;
    CHECK(rssn_dag_compile_with_custom_ops(dag2, relu_node, reg, &relu_fn),
          "compile relu(x+3)");

    const double test_xs[] = { -5.0, -3.0, -1.0, 0.0, 2.0 };
    for (size_t i = 0; i < sizeof test_xs / sizeof test_xs[0]; i++) {
        double result = 0.0;
        double v[] = { test_xs[i] };
        CHECK(rssn_dag_execute_v2(relu_fn, v, &result), "execute relu");
        /* Reference computed independently of relu_impl. */
        double shifted = test_xs[i] + 3.0;
        double expected = shifted > 0.0 ? shifted : 0.0;
        printf(" relu(%.1f + 3) = %.1f expected %.1f%s\n",
               test_xs[i], result, expected,
               approx_equal(result, expected) ? " ✓" : " ✗");
    }

    /* NOTE(review): fn_ptr and relu_fn are never released in this file. This
       file does not show whether they are owned by their DagBuilder/registry
       or need their own free call; confirm against rssn-advanced.h. */
    rc = 0;

cleanup:
    /* Same release order as before: registry, dag2, dag. The handles are
       NULL-checked because these free functions' NULL behavior is not
       documented here. */
    if (reg) {
        rssn_custom_op_registry_free(reg);
    }
    if (dag2) {
        rssn_dag_free(dag2);
    }
    if (dag) {
        rssn_dag_free(dag);
    }
    if (rc == 0) {
        puts("\nDone.");
    }
    return rc;
}