use std::fmt::Write;
/// Builds the complete WGSL source of the Taylor-mode forward kernel whose
/// jets carry `k` coefficients.
///
/// The shader is the concatenation, in this order, of: a two-line header
/// comment, opcode constants, resource bindings, scalar helpers, the `JetK`
/// type with its load/store routines, jet arithmetic, transcendental jet
/// rules, inverse-trig jet rules, and the `main` entry point.
///
/// # Panics
///
/// Panics if `k` is not in `1..=5`.
#[must_use]
pub fn generate_taylor_wgsl(k: usize) -> String {
    assert!((1..=5).contains(&k), "K must be in 1..=5, got {k}");

    let mut shader = String::with_capacity(16384);
    write!(
        shader,
        "// Auto-generated Taylor forward K={k} kernel.\n\
         // Do not edit — generated by taylor_codegen.rs.\n\n"
    )
    .unwrap();

    write_wgsl_opcodes(&mut shader);
    write_wgsl_bindings(&mut shader, k);
    write_wgsl_helpers(&mut shader);
    write_wgsl_jet_type(&mut shader, k);
    write_wgsl_jet_arithmetic(&mut shader, k);
    write_wgsl_jet_transcendental(&mut shader, k);
    write_wgsl_jet_inverse_trig(&mut shader, k);
    write_wgsl_main_kernel(&mut shader, k);

    shader
}
/// Builds the complete CUDA source of the Taylor-mode forward kernel whose
/// jets carry `k` coefficients.
///
/// The emitted text starts with a header comment and `typedef FLOAT_TYPE F;`.
/// `FLOAT_TYPE` is not defined here; presumably the compiler invocation
/// supplies it (TODO confirm). Opcode defines, helpers, the `JetK` type, jet
/// arithmetic, transcendental and inverse-trig rules, and the kernel follow.
///
/// # Panics
///
/// Panics if `k` is not in `1..=5`.
#[must_use]
pub fn generate_taylor_cuda(k: usize) -> String {
    assert!((1..=5).contains(&k), "K must be in 1..=5, got {k}");

    let mut kernel = String::with_capacity(16384);
    write!(
        kernel,
        "// Auto-generated Taylor forward K={k} kernel.\n\
         // Do not edit — generated by taylor_codegen.rs.\n\n\
         typedef FLOAT_TYPE F;\n"
    )
    .unwrap();

    write_cuda_opcodes(&mut kernel);
    write_cuda_helpers(&mut kernel);
    write_cuda_jet_type(&mut kernel, k);
    write_cuda_jet_arithmetic(&mut kernel, k);
    write_cuda_jet_transcendental(&mut kernel, k);
    write_cuda_jet_inverse_trig(&mut kernel, k);
    write_cuda_main_kernel(&mut kernel, k);

    kernel
}
/// Emits one WGSL `const OP_*: u32` declaration per tape opcode, followed by a
/// blank line.
///
/// Each opcode's numeric value is its position in `NAMES`. The order must
/// therefore stay in sync with the numeric `case` labels used by the main
/// kernel's dispatch `switch`.
fn write_wgsl_opcodes(s: &mut String) {
    const NAMES: [&str; 43] = [
        "OP_INPUT", "OP_CONST", "OP_ADD", "OP_SUB", "OP_MUL", "OP_DIV", "OP_REM",
        "OP_POWF", "OP_ATAN2", "OP_HYPOT", "OP_MAX", "OP_MIN", "OP_NEG", "OP_RECIP",
        "OP_SQRT", "OP_CBRT", "OP_POWI", "OP_EXP", "OP_EXP2", "OP_EXPM1", "OP_LN",
        "OP_LOG2", "OP_LOG10", "OP_LN1P", "OP_SIN", "OP_COS", "OP_TAN", "OP_ASIN",
        "OP_ACOS", "OP_ATAN", "OP_SINH", "OP_COSH", "OP_TANH", "OP_ASINH", "OP_ACOSH",
        "OP_ATANH", "OP_ABS", "OP_SIGNUM", "OP_FLOOR", "OP_CEIL", "OP_ROUND",
        "OP_TRUNC", "OP_FRACT",
    ];
    for (code, name) in NAMES.iter().enumerate() {
        writeln!(s, "const {name}: u32 = {code}u;").unwrap();
    }
    s.push('\n');
}
/// Emits the `TapeMeta` uniform struct and every `@group/@binding` resource
/// declaration, followed by a blank line.
///
/// Group 0 holds the tape description (opcodes, operand indices, constants,
/// metadata, output indices). Group 1 holds the primal inputs, direction
/// seeds, the `jets` scratch buffer and `jet_outputs`. The text does not
/// depend on the jet order; `k` is accepted only so this writer has the same
/// `(s, k)` shape as the other per-K writers.
fn write_wgsl_bindings(s: &mut String, k: usize) {
    let _ = k;
    const LAYOUT: &str = "struct TapeMeta {
num_ops: u32,
num_inputs: u32,
num_variables: u32,
num_outputs: u32,
batch_size: u32,
_pad0: u32,
_pad1: u32,
_pad2: u32,
}
@group(0) @binding(0) var<storage, read> opcodes: array<u32>;
@group(0) @binding(1) var<storage, read> arg0: array<u32>;
@group(0) @binding(2) var<storage, read> arg1: array<u32>;
@group(0) @binding(3) var<storage, read> constants: array<f32>;
@group(0) @binding(4) var<uniform> tape_meta: TapeMeta;
@group(0) @binding(5) var<storage, read> output_indices: array<u32>;
@group(1) @binding(0) var<storage, read> primal_inputs: array<f32>;
@group(1) @binding(1) var<storage, read> direction_seeds: array<f32>;
@group(1) @binding(2) var<storage, read_write> jets: array<f32>;
@group(1) @binding(3) var<storage, read_write> jet_outputs: array<f32>;
";
    s.push_str(LAYOUT);
    s.push('\n');
}
/// Emits scalar `f32` helpers for the hyperbolic functions and their inverses.
/// The jet rules call these when computing the zeroth-order coefficient.
///
/// `asinh_f` uses the odd symmetry `asinh(-x) = -asinh(x)`. The naive
/// `log(x + sqrt(x*x + 1))` cancels catastrophically for negative `x`. Once
/// `|x| >= 4096`, `x*x + 1 == x*x` in f32, so the argument collapses to 0 and
/// the naive form returns `-inf`.
fn write_wgsl_helpers(s: &mut String) {
    writeln!(
        s,
        "fn sinh_f(x: f32) -> f32 {{ return (exp(x) - exp(-x)) * 0.5; }}
fn cosh_f(x: f32) -> f32 {{ return (exp(x) + exp(-x)) * 0.5; }}
fn asinh_f(x: f32) -> f32 {{ let ax = abs(x); return sign(x) * log(ax + sqrt(ax * ax + 1.0)); }}
fn acosh_f(x: f32) -> f32 {{ return log(x + sqrt((x - 1.0) * (x + 1.0))); }}
fn atanh_f(x: f32) -> f32 {{ return 0.5 * log((1.0 + x) / (1.0 - x)); }}
"
    )
    .unwrap();
}
/// Emits the `JetK` struct (a fixed array of `k` f32 coefficients) and three
/// helpers:
/// - `jet_const(val)`: slot 0 is `val`, every higher slot is set to 0.0.
/// - `jet_load(base)`: reads `k` consecutive floats from `jets` starting at
///   `base`.
/// - `jet_store(base, j)`: writes the `k` coefficients back to the same
///   positions.
fn write_wgsl_jet_type(s: &mut String, k: usize) {
    writeln!(s, "struct JetK {{ v: array<f32, {k}>, }}\n").unwrap();

    s.push_str("fn jet_const(val: f32) -> JetK {\n var j: JetK;\n j.v[0] = val;\n");
    for slot in 1..k {
        writeln!(s, " j.v[{slot}] = 0.0;").unwrap();
    }
    s.push_str(" return j;\n}\n\n");

    s.push_str("fn jet_load(base: u32) -> JetK {\n var j: JetK;\n");
    for slot in 0..k {
        writeln!(s, " j.v[{slot}] = jets[base + {slot}u];").unwrap();
    }
    s.push_str(" return j;\n}\n\n");

    s.push_str("fn jet_store(base: u32, j: JetK) {\n");
    for slot in 0..k {
        writeln!(s, " jets[base + {slot}u] = j.v[{slot}];").unwrap();
    }
    s.push_str("}\n\n");
}
/// Emits the fully unrolled WGSL jet arithmetic: `jet_add`, `jet_sub`,
/// `jet_neg`, `jet_scale`, `jet_mul`, `jet_div` and `jet_recip`.
///
/// Products use the Cauchy convolution `c_i = Σ_{j=0..i} a_j b_{i-j}`.
/// Quotient and reciprocal solve that convolution for `c` one coefficient
/// at a time, dividing by the zeroth coefficient of the divisor.
fn write_wgsl_jet_arithmetic(s: &mut String, k: usize) {
    /// `"x.v[from] * y.v[i-from] + … + x.v[i] * y.v[0]"`.
    fn conv(x: &str, y: &str, i: usize, from: usize) -> String {
        (from..=i)
            .map(|j| format!("{x}.v[{j}] * {y}.v[{}]", i - j))
            .collect::<Vec<_>>()
            .join(" + ")
    }

    /// Writes `fn <signature> {`, the `var c` declaration, the `setup`
    /// statements, one `line(i)` for each `i` in `first..k`, and `return c`.
    fn emit(
        s: &mut String,
        signature: &str,
        setup: &[&str],
        first: usize,
        k: usize,
        line: impl Fn(usize) -> String,
    ) {
        writeln!(s, "fn {signature} {{").unwrap();
        writeln!(s, " var c: JetK;").unwrap();
        for stmt in setup {
            writeln!(s, "{stmt}").unwrap();
        }
        for i in first..k {
            writeln!(s, "{}", line(i)).unwrap();
        }
        writeln!(s, " return c;\n}}\n").unwrap();
    }

    // Coefficient-wise operations.
    emit(s, "jet_add(a: JetK, b: JetK) -> JetK", &[], 0, k, |i| {
        format!(" c.v[{i}] = a.v[{i}] + b.v[{i}];")
    });
    emit(s, "jet_sub(a: JetK, b: JetK) -> JetK", &[], 0, k, |i| {
        format!(" c.v[{i}] = a.v[{i}] - b.v[{i}];")
    });
    emit(s, "jet_neg(a: JetK) -> JetK", &[], 0, k, |i| {
        format!(" c.v[{i}] = -a.v[{i}];")
    });
    emit(s, "jet_scale(a: JetK, s: f32) -> JetK", &[], 0, k, |i| {
        format!(" c.v[{i}] = a.v[{i}] * s;")
    });

    // Truncated power-series product.
    emit(s, "jet_mul(a: JetK, b: JetK) -> JetK", &[], 0, k, |i| {
        format!(" c.v[{i}] = {};", conv("a", "b", i, 0))
    });

    // a = b * c solved for c.
    emit(
        s,
        "jet_div(a: JetK, b: JetK) -> JetK",
        &[" let inv_b0 = 1.0 / b.v[0];", " c.v[0] = a.v[0] * inv_b0;"],
        1,
        k,
        |i| format!(" c.v[{i}] = (a.v[{i}] - ({})) * inv_b0;", conv("b", "c", i, 1)),
    );

    // 1 = a * c solved for c.
    emit(
        s,
        "jet_recip(a: JetK) -> JetK",
        &[" c.v[0] = 1.0 / a.v[0];"],
        1,
        k,
        |i| format!(" c.v[{i}] = -({}) * c.v[0];", conv("a", "c", i, 1)),
    );
}
/// Emits WGSL Taylor-coefficient recurrences for `exp`, `ln`, `sqrt`, the
/// paired `sin/cos` and `sinh/cosh` (returned as a `JetPair`), `tan`, and `tanh`.
///
/// Most rules use the ODE form `c' = a' · g`, which on normalized Taylor
/// coefficients becomes `c_i = (1/i) Σ_{j=1..i} j · a_j · g_{i-j}`. `1/i` is
/// printed with ten decimal places; `j` is printed as `j.0`.
fn write_wgsl_jet_transcendental(s: &mut String, k: usize) {
    /// `1/i` rendered as a ten-decimal literal.
    fn recip_lit(i: usize) -> String {
        format!("{:.10}", 1.0 / i as f64)
    }

    /// `"1.0 * x.v[1] * y.v[i-1] + 2.0 * x.v[2] * y.v[i-2] + …"` over `js`.
    fn weighted(x: &str, y: &str, i: usize, js: impl Iterator<Item = usize>) -> String {
        js.map(|j| format!("{:.1} * {x}.v[{j}] * {y}.v[{}]", j as f64, i - j))
            .collect::<Vec<_>>()
            .join(" + ")
    }

    /// Unweighted Cauchy terms `"x.v[j] * y.v[i-j]"` over `js`.
    fn cauchy(x: &str, y: &str, i: usize, js: impl Iterator<Item = usize>) -> String {
        js.map(|j| format!("{x}.v[{j}] * {y}.v[{}]", i - j))
            .collect::<Vec<_>>()
            .join(" + ")
    }

    /// Writes `header`, each body line, then `tail` and the closing brace
    /// followed by a blank line.
    fn emit(s: &mut String, header: &str, body: &[String], tail: &str) {
        s.push_str(header);
        s.push('\n');
        for line in body {
            s.push_str(line);
            s.push('\n');
        }
        s.push_str(tail);
        s.push_str("\n}\n\n");
    }

    // exp: c' = a' · c
    let mut body = vec![
        " var c: JetK;".to_string(),
        " c.v[0] = exp(a.v[0]);".to_string(),
    ];
    body.extend((1..k).map(|i| {
        format!(" c.v[{i}] = {} * ({});", recip_lit(i), weighted("a", "c", i, 1..=i))
    }));
    emit(s, "fn jet_exp(a: JetK) -> JetK {", &body, " return c;");

    // ln: a · c' = a'  =>  c_i = (a_i - (1/i) Σ_{j<i} j c_j a_{i-j}) / a_0
    let mut body = vec![
        " var c: JetK;".to_string(),
        " let inv_a0 = 1.0 / a.v[0];".to_string(),
        " c.v[0] = log(a.v[0]);".to_string(),
    ];
    body.extend((1..k).map(|i| match i {
        1 => " c.v[1] = a.v[1] * inv_a0;".to_string(),
        _ => format!(
            " c.v[{i}] = (a.v[{i}] - {} * ({})) * inv_a0;",
            recip_lit(i),
            weighted("c", "a", i, 1..i)
        ),
    }));
    emit(s, "fn jet_ln(a: JetK) -> JetK {", &body, " return c;");

    // sqrt: c · c = a  =>  c_i = (a_i - Σ_{0<j<i} c_j c_{i-j}) / (2 c_0)
    let mut body = vec![
        " var c: JetK;".to_string(),
        " c.v[0] = sqrt(a.v[0]);".to_string(),
    ];
    if k > 1 {
        body.push(" let inv_2c0 = 0.5 / c.v[0];".to_string());
    }
    body.extend((1..k).map(|i| match i {
        1 => " c.v[1] = a.v[1] * inv_2c0;".to_string(),
        _ => format!(
            " c.v[{i}] = (a.v[{i}] - ({})) * inv_2c0;",
            cauchy("c", "c", i, 1..i)
        ),
    }));
    emit(s, "fn jet_sqrt(a: JetK) -> JetK {", &body, " return c;");

    s.push_str("struct JetPair { a: JetK, b: JetK, }\n\n");

    // Coupled pairs: p' = a' q and q' = ±a' p.
    for (name, p, q, p0, q0, q_sign) in [
        ("jet_sin_cos", "sn", "co", "sin", "cos", "-"),
        ("jet_sinh_cosh", "sh", "ch", "sinh_f", "cosh_f", ""),
    ] {
        let mut body = vec![
            format!(" var {p}: JetK;"),
            format!(" var {q}: JetK;"),
            format!(" {p}.v[0] = {p0}(a.v[0]);"),
            format!(" {q}.v[0] = {q0}(a.v[0]);"),
        ];
        for i in 1..k {
            let inv = recip_lit(i);
            body.push(format!(" {p}.v[{i}] = {inv} * ({});", weighted("a", q, i, 1..=i)));
            body.push(format!(
                " {q}.v[{i}] = {q_sign}{inv} * ({});",
                weighted("a", p, i, 1..=i)
            ));
        }
        emit(
            s,
            &format!("fn {name}(a: JetK) -> JetPair {{"),
            &body,
            &format!(" return JetPair({p}, {q});"),
        );
    }

    // tan / tanh: c' = a' · sc with sc = 1 ± c², updated alongside c.
    for (name, f0, op, negate) in [("jet_tan", "tan", '+', false), ("jet_tanh", "tanh", '-', true)] {
        let mut body = vec![
            " var c: JetK;".to_string(),
            format!(" var sc: JetK; // 1 {op} c²"),
            format!(" c.v[0] = {f0}(a.v[0]);"),
            format!(" sc.v[0] = 1.0 {op} c.v[0] * c.v[0];"),
        ];
        for i in 1..k {
            body.push(format!(
                " c.v[{i}] = {} * ({});",
                recip_lit(i),
                weighted("a", "sc", i, 1..=i)
            ));
            let square = cauchy("c", "c", i, 0..=i);
            body.push(if negate {
                format!(" sc.v[{i}] = -({square});")
            } else {
                format!(" sc.v[{i}] = {square};")
            });
        }
        emit(s, &format!("fn {name}(a: JetK) -> JetK {{"), &body, " return c;");
    }
}
/// Emits WGSL jet rules for `atan`, `asin`, `acos`, `asinh`, `acosh` and
/// `atanh`.
///
/// Every function builds a jet `d` from `a²`, derives `g` from it
/// (`1/d` or `1/sqrt(d)`), evaluates the scalar function for the zeroth
/// coefficient, and fills the rest with
/// `c_i = ±(1/i) Σ_{j=1..i} j · a_j · g_{i-j}`.
fn write_wgsl_jet_inverse_trig(s: &mut String, k: usize) {
    // (jet fn, d.v[0] expr, sign on asq.v[i], g expr, scalar fn, sign on c.v[i])
    let specs: [(&str, &str, &str, &str, &str, &str); 6] = [
        ("jet_atan", "1.0 + asq.v[0]", "", "jet_recip(d)", "atan", ""),
        ("jet_asin", "1.0 - asq.v[0]", "-", "jet_recip(jet_sqrt(d))", "asin", ""),
        ("jet_acos", "1.0 - asq.v[0]", "-", "jet_recip(jet_sqrt(d))", "acos", "-"),
        ("jet_asinh", "1.0 + asq.v[0]", "", "jet_recip(jet_sqrt(d))", "asinh_f", ""),
        (
            "jet_acosh",
            "(a.v[0] - 1.0) * (a.v[0] + 1.0)",
            "",
            "jet_recip(jet_sqrt(d))",
            "acosh_f",
            "",
        ),
        ("jet_atanh", "1.0 - asq.v[0]", "-", "jet_recip(d)", "atanh_f", ""),
    ];

    for (name, d0, d_sign, g_expr, scalar, c_sign) in specs {
        writeln!(s, "fn {name}(a: JetK) -> JetK {{").unwrap();
        writeln!(s, " let asq = jet_mul(a, a);").unwrap();
        writeln!(s, " var d: JetK;").unwrap();
        writeln!(s, " d.v[0] = {d0};").unwrap();
        for i in 1..k {
            writeln!(s, " d.v[{i}] = {d_sign}asq.v[{i}];").unwrap();
        }
        writeln!(s, " let g = {g_expr};").unwrap();
        writeln!(s, " var c: JetK;").unwrap();
        writeln!(s, " c.v[0] = {scalar}(a.v[0]);").unwrap();
        for i in 1..k {
            let inv_i = 1.0 / i as f64;
            let sum = (1..=i)
                .map(|j| format!("{:.1} * a.v[{j}] * g.v[{}]", j as f64, i - j))
                .collect::<Vec<_>>()
                .join(" + ");
            writeln!(s, " c.v[{i}] = {c_sign}{inv_i:.10} * ({sum});").unwrap();
        }
        writeln!(s, " return c;\n}}\n").unwrap();
    }
}
/// Emits the WGSL `main` compute entry point (`@workgroup_size(256)`).
///
/// Each invocation handles one batch element `bid = gid.x` and works on the
/// `nv * K` floats of `jets` starting at `bid * nv * K`. It
/// 1. fills every variable's jet from `constants[i]`, zeroing the higher
///    coefficients;
/// 2. overwrites the first `ni` variables with `primal_inputs` and, when
///    `K > 1`, puts `direction_seeds` in the first-order coefficient;
/// 3. walks the tape from index `ni`, skipping `OP_CONST` entries,
///    dispatching on the opcode, and storing the result jet at the op's own
///    variable index;
/// 4. copies the jets named by `output_indices` into `jet_outputs`.
///
/// Operands are read from `arg0`/`arg1`. For `OP_POWI` (case 16),
/// `arg1` is bit-cast to an `i32` exponent instead of being used as a
/// variable index.
fn write_wgsl_main_kernel(s: &mut String, k: usize) {
    writeln!(
        s,
        "@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{
let bid = gid.x;
if bid >= tape_meta.batch_size {{
return;
}}
let nv = tape_meta.num_variables;
let ni = tape_meta.num_inputs;
let num_ops = tape_meta.num_ops;
let n_out = tape_meta.num_outputs;
let K = {k}u;
let j_base = bid * nv * K;"
    )
    .unwrap();

    // --- 1. constants ---------------------------------------------------
    writeln!(s, "\n // Initialize from constants").unwrap();
    writeln!(s, " for (var i = 0u; i < nv; i = i + 1u) {{").unwrap();
    writeln!(s, " let off = j_base + i * K;").unwrap();
    writeln!(s, " jets[off] = constants[i];").unwrap();
    for c in 1..k {
        writeln!(s, " jets[off + {c}u] = 0.0;").unwrap();
    }
    writeln!(s, " }}").unwrap();

    // --- 2. inputs ------------------------------------------------------
    writeln!(s, "\n // Set input jets").unwrap();
    writeln!(s, " let in_base = bid * ni;").unwrap();
    writeln!(s, " for (var i = 0u; i < ni; i = i + 1u) {{").unwrap();
    writeln!(s, " let off = j_base + i * K;").unwrap();
    writeln!(s, " jets[off] = primal_inputs[in_base + i];").unwrap();
    if k > 1 {
        writeln!(s, " jets[off + 1u] = direction_seeds[in_base + i];").unwrap();
    }
    writeln!(s, " }}").unwrap();

    // --- 3. tape walk ---------------------------------------------------
    writeln!(s, "\n // Walk the tape").unwrap();
    writeln!(s, " for (var i = ni; i < num_ops; i = i + 1u) {{").unwrap();
    writeln!(s, " let op = opcodes[i];").unwrap();
    writeln!(s, " if op == OP_CONST {{ continue; }}").unwrap();
    writeln!(s, " let a_idx = arg0[i];").unwrap();
    writeln!(s, " let b_idx = arg1[i];").unwrap();
    writeln!(s, " let a = jet_load(j_base + a_idx * K);").unwrap();
    writeln!(s, " var r = jet_const(0.0);").unwrap();
    writeln!(s, " switch op {{").unwrap();

    // Binary ops that map directly onto a jet helper.
    for (case, name) in &[
        (2, "jet_add"),
        (3, "jet_sub"),
        (4, "jet_mul"),
        (5, "jet_div"),
    ] {
        writeln!(s, " case {case}u: {{").unwrap();
        writeln!(s, " let b = jet_load(j_base + b_idx * K);").unwrap();
        writeln!(s, " r = {name}(a, b);").unwrap();
        writeln!(s, " }}").unwrap();
    }

    // OP_REM: a - trunc(a0/b0) * b, with the quotient held constant.
    writeln!(s, " case 6u: {{").unwrap();
    writeln!(s, " let b = jet_load(j_base + b_idx * K);").unwrap();
    writeln!(s, " let q = trunc(a.v[0] / b.v[0]);").unwrap();
    writeln!(s, " r.v[0] = a.v[0] - q * b.v[0];").unwrap();
    for i in 1..k {
        writeln!(s, " r.v[{i}] = a.v[{i}] - q * b.v[{i}];").unwrap();
    }
    writeln!(s, " }}").unwrap();

    // OP_POWF: non-positive base keeps only a first-order term; orders >= 2
    // are NaN. Positive base goes through exp(b * ln a).
    writeln!(s, " case 7u: {{").unwrap();
    writeln!(s, " let b = jet_load(j_base + b_idx * K);").unwrap();
    writeln!(s, " if a.v[0] <= 0.0 {{").unwrap();
    writeln!(s, " let val = pow(a.v[0], b.v[0]);").unwrap();
    writeln!(s, " let da = b.v[0] * pow(a.v[0], b.v[0] - 1.0);").unwrap();
    writeln!(s, " let db = 0.0;").unwrap();
    writeln!(s, " r.v[0] = val;").unwrap();
    if k > 1 {
        writeln!(s, " r.v[1] = da * a.v[1] + db * b.v[1];").unwrap();
    }
    for i in 2..k {
        writeln!(s, " r.v[{i}] = bitcast<f32>(0x7fc00000u);").unwrap();
    }
    writeln!(s, " }} else {{").unwrap();
    writeln!(s, " let lna = jet_ln(a);").unwrap();
    writeln!(s, " let product = jet_mul(b, lna);").unwrap();
    writeln!(s, " r = jet_exp(product);").unwrap();
    writeln!(s, " r.v[0] = pow(a.v[0], b.v[0]);").unwrap();
    writeln!(s, " }}").unwrap();
    writeln!(s, " }}").unwrap();

    // OP_ATAN2: when b0 == 0, use the -atan(b/a) form (derivatives are the
    // same up to sign).
    writeln!(s, " case 8u: {{").unwrap();
    writeln!(s, " let b = jet_load(j_base + b_idx * K);").unwrap();
    writeln!(s, " if b.v[0] == 0.0 {{").unwrap();
    writeln!(s, " r.v[0] = atan2(a.v[0], b.v[0]);").unwrap();
    writeln!(s, " if a.v[0] == 0.0 {{").unwrap();
    for i in 1..k {
        writeln!(s, " r.v[{i}] = 0.0;").unwrap();
    }
    writeln!(s, " }} else {{").unwrap();
    writeln!(s, " let ratio = jet_div(b, a);").unwrap();
    writeln!(s, " let c = jet_atan(ratio);").unwrap();
    for i in 1..k {
        writeln!(s, " r.v[{i}] = -c.v[{i}];").unwrap();
    }
    writeln!(s, " }}").unwrap();
    writeln!(s, " }} else {{").unwrap();
    writeln!(s, " let ratio = jet_div(a, b);").unwrap();
    writeln!(s, " r = jet_atan(ratio);").unwrap();
    writeln!(s, " r.v[0] = atan2(a.v[0], b.v[0]);").unwrap();
    writeln!(s, " }}").unwrap();
    writeln!(s, " }}").unwrap();

    // OP_HYPOT: explicit inf / NaN handling, then a scaled sqrt(a² + b²).
    // When both primals are zero, the jets are shifted down one order.
    writeln!(s, " case 9u: {{").unwrap();
    writeln!(s, " let b = jet_load(j_base + b_idx * K);").unwrap();
    writeln!(s, " let aa = abs(a.v[0]);").unwrap();
    writeln!(s, " let bb = abs(b.v[0]);").unwrap();
    writeln!(s, " let inf = bitcast<f32>(0x7f800000u);").unwrap();
    writeln!(s, " if (aa == inf || bb == inf) {{").unwrap();
    writeln!(s, " let nan = inf - inf;").unwrap();
    writeln!(s, " r.v[0] = inf;").unwrap();
    for i in 1..k {
        writeln!(s, " r.v[{i}] = nan;").unwrap();
    }
    writeln!(s, " }} else if (").unwrap();
    writeln!(s, " (bitcast<u32>(a.v[0]) & 0x7fffffffu) > 0x7f800000u ||").unwrap();
    writeln!(s, " (bitcast<u32>(b.v[0]) & 0x7fffffffu) > 0x7f800000u").unwrap();
    writeln!(s, " ) {{").unwrap();
    writeln!(s, " r.v[0] = a.v[0] + b.v[0];").unwrap();
    for i in 1..k {
        writeln!(s, " r.v[{i}] = 0.0;").unwrap();
    }
    writeln!(s, " }} else {{").unwrap();
    writeln!(s, " let h = max(aa, bb);").unwrap();
    writeln!(s, " if (h == 0.0) {{").unwrap();
    if k >= 2 {
        writeln!(s, " var a_shifted: JetK;").unwrap();
        writeln!(s, " var b_shifted: JetK;").unwrap();
        for i in 0..(k - 1) {
            writeln!(s, " a_shifted.v[{i}] = a.v[{ip1}];", ip1 = i + 1).unwrap();
            writeln!(s, " b_shifted.v[{i}] = b.v[{ip1}];", ip1 = i + 1).unwrap();
        }
        writeln!(s, " a_shifted.v[{last}] = 0.0;", last = k - 1).unwrap();
        writeln!(s, " b_shifted.v[{last}] = 0.0;", last = k - 1).unwrap();
        writeln!(
            s,
            " let h_inner = max(abs(a_shifted.v[0]), abs(b_shifted.v[0]));"
        )
        .unwrap();
        writeln!(s, " if (h_inner == 0.0) {{").unwrap();
        writeln!(s, " r.v[0] = 0.0;").unwrap();
        for i in 1..k {
            writeln!(s, " r.v[{i}] = inf;").unwrap();
        }
        writeln!(s, " }} else {{").unwrap();
        writeln!(s, " let inv_h_inner = 1.0 / h_inner;").unwrap();
        writeln!(s, " let a_ss = jet_scale(a_shifted, inv_h_inner);").unwrap();
        writeln!(s, " let b_ss = jet_scale(b_shifted, inv_h_inner);").unwrap();
        writeln!(
            s,
            " let sum_sq_s = jet_add(jet_mul(a_ss, a_ss), jet_mul(b_ss, b_ss));"
        )
        .unwrap();
        writeln!(s, " let r_s_inner = jet_sqrt(sum_sq_s);").unwrap();
        writeln!(s, " let inner = jet_scale(r_s_inner, h_inner);").unwrap();
        writeln!(s, " r.v[0] = 0.0;").unwrap();
        for i in 1..k {
            writeln!(s, " r.v[{i}] = inner.v[{im1}];", im1 = i - 1).unwrap();
        }
        writeln!(s, " }}").unwrap();
    } else {
        writeln!(s, " r.v[0] = 0.0;").unwrap();
    }
    writeln!(s, " }} else {{").unwrap();
    writeln!(s, " let inv_h = 1.0 / h;").unwrap();
    writeln!(s, " let a_s = jet_scale(a, inv_h);").unwrap();
    writeln!(s, " let b_s = jet_scale(b, inv_h);").unwrap();
    writeln!(
        s,
        " let sum_sq = jet_add(jet_mul(a_s, a_s), jet_mul(b_s, b_s));"
    )
    .unwrap();
    writeln!(s, " let r_s = jet_sqrt(sum_sq);").unwrap();
    writeln!(s, " r = jet_scale(r_s, h);").unwrap();
    writeln!(s, " let a_s0 = a.v[0] * inv_h; let b_s0 = b.v[0] * inv_h;").unwrap();
    writeln!(s, " r.v[0] = h * sqrt(a_s0 * a_s0 + b_s0 * b_s0);").unwrap();
    writeln!(s, " }}").unwrap();
    writeln!(s, " }}").unwrap();
    writeln!(s, " }}").unwrap();

    // OP_MAX / OP_MIN: pick the whole jet of the winning primal.
    writeln!(s, " case 10u: {{").unwrap();
    writeln!(s, " let b = jet_load(j_base + b_idx * K);").unwrap();
    writeln!(s, " if a.v[0] >= b.v[0] {{ r = a; }} else {{ r = b; }}").unwrap();
    writeln!(s, " }}").unwrap();
    writeln!(s, " case 11u: {{").unwrap();
    writeln!(s, " let b = jet_load(j_base + b_idx * K);").unwrap();
    writeln!(s, " if a.v[0] <= b.v[0] {{ r = a; }} else {{ r = b; }}").unwrap();
    writeln!(s, " }}").unwrap();

    writeln!(s, " case 12u: {{ r = jet_neg(a); }}").unwrap();
    writeln!(s, " case 13u: {{ r = jet_recip(a); }}").unwrap();
    writeln!(s, " case 14u: {{ r = jet_sqrt(a); }}").unwrap();

    // OP_CBRT: sign(a) * exp(ln|a| / 3); higher orders are +inf at a0 == 0.
    writeln!(s, " case 15u: {{").unwrap();
    writeln!(s, " if a.v[0] == 0.0 {{").unwrap();
    writeln!(s, " r.v[0] = 0.0;").unwrap();
    for i in 1..k {
        writeln!(s, " r.v[{i}] = bitcast<f32>(0x7f800000u);").unwrap();
    }
    writeln!(s, " }} else {{").unwrap();
    writeln!(s, " let sg = select(sign(a.v[0]), 1.0, a.v[0] == 0.0);").unwrap();
    write!(s, " var abs_a: JetK;\n abs_a.v[0] = abs(a.v[0]);\n").unwrap();
    for i in 1..k {
        writeln!(s, " abs_a.v[{i}] = sg * a.v[{i}];").unwrap();
    }
    writeln!(s, " let lna = jet_ln(abs_a);").unwrap();
    writeln!(s, " let third = jet_scale(lna, 1.0 / 3.0);").unwrap();
    writeln!(s, " let e = jet_exp(third);").unwrap();
    writeln!(s, " r.v[0] = sg * e.v[0];").unwrap();
    for i in 1..k {
        writeln!(s, " r.v[{i}] = sg * e.v[{i}];").unwrap();
    }
    writeln!(s, " }}").unwrap();
    writeln!(s, " }}").unwrap();

    // OP_POWI: integer exponent bit-cast from arg1. Small exponents at a
    // zero base use repeated multiplication. Any other non-positive base goes
    // through |a| with the parity sign `sf`.
    writeln!(s, " case 16u: {{").unwrap();
    writeln!(s, " let n = f32(bitcast<i32>(b_idx));").unwrap();
    writeln!(s, " let ni = bitcast<i32>(b_idx);").unwrap();
    writeln!(s, " if n == 0.0 {{").unwrap();
    writeln!(s, " r = jet_const(1.0);").unwrap();
    writeln!(s, " }} else if n == 1.0 {{").unwrap();
    writeln!(s, " r = a;").unwrap();
    writeln!(s, " }} else if a.v[0] == 0.0 && ni == 2 {{").unwrap();
    writeln!(s, " r = jet_mul(a, a);").unwrap();
    writeln!(s, " }} else if a.v[0] == 0.0 && ni == 3 {{").unwrap();
    writeln!(s, " r = jet_mul(jet_mul(a, a), a);").unwrap();
    writeln!(s, " }} else if a.v[0] == 0.0 && ni == 4 {{").unwrap();
    writeln!(s, " let a2 = jet_mul(a, a); r = jet_mul(a2, a2);").unwrap();
    writeln!(s, " }} else if a.v[0] == 0.0 && ni == 5 {{").unwrap();
    writeln!(
        s,
        " let a2 = jet_mul(a, a); let a4 = jet_mul(a2, a2); r = jet_mul(a4, a);"
    )
    .unwrap();
    writeln!(s, " }} else if a.v[0] == 0.0 && ni == 6 {{").unwrap();
    writeln!(
        s,
        " let a2 = jet_mul(a, a); let a4 = jet_mul(a2, a2); r = jet_mul(a4, a2);"
    )
    .unwrap();
    writeln!(s, " }} else if a.v[0] == 0.0 && ni == 7 {{").unwrap();
    writeln!(
        s,
        " let a2 = jet_mul(a, a); let a4 = jet_mul(a2, a2); r = jet_mul(jet_mul(a4, a2), a);"
    )
    .unwrap();
    writeln!(s, " }} else if a.v[0] == 0.0 && ni == 8 {{").unwrap();
    writeln!(
        s,
        " let a2 = jet_mul(a, a); let a4 = jet_mul(a2, a2); r = jet_mul(a4, a4);"
    )
    .unwrap();
    writeln!(s, " }} else if a.v[0] <= 0.0 {{").unwrap();
    writeln!(s, " let sf = select(1.0, -1.0, ni % 2 != 0);").unwrap();
    writeln!(s, " let sg = select(sign(a.v[0]), 1.0, a.v[0] == 0.0);").unwrap();
    write!(s, " var abs_a: JetK;\n abs_a.v[0] = abs(a.v[0]);\n").unwrap();
    for i in 1..k {
        writeln!(s, " abs_a.v[{i}] = sg * a.v[{i}];").unwrap();
    }
    writeln!(s, " let lna = jet_ln(abs_a);").unwrap();
    writeln!(s, " let nlna = jet_scale(lna, n);").unwrap();
    writeln!(s, " let e = jet_exp(nlna);").unwrap();
    for i in 0..k {
        writeln!(s, " r.v[{i}] = sf * e.v[{i}];").unwrap();
    }
    // Refine the primal with pow. WGSL `pow` is undefined for a negative
    // base (it is specified via exp2(e2 * log2(e1))), so it must be applied
    // to |a| and then multiplied by the parity sign. pow(a0, n) would give
    // NaN here.
    writeln!(s, " r.v[0] = sf * pow(abs(a.v[0]), n);").unwrap();
    writeln!(s, " }} else {{").unwrap();
    writeln!(s, " let lna = jet_ln(a);").unwrap();
    writeln!(s, " let nlna = jet_scale(lna, n);").unwrap();
    writeln!(s, " r = jet_exp(nlna);").unwrap();
    writeln!(s, " r.v[0] = pow(a.v[0], n);").unwrap();
    writeln!(s, " }}").unwrap();
    writeln!(s, " }}").unwrap();

    // Exponentials and logarithms (primals recomputed with the scalar builtin).
    writeln!(s, " case 17u: {{ r = jet_exp(a); }}").unwrap();
    writeln!(s, " case 18u: {{").unwrap();
    writeln!(s, " let ln2 = log(2.0);").unwrap();
    writeln!(s, " let scaled = jet_scale(a, ln2);").unwrap();
    writeln!(s, " r = jet_exp(scaled);").unwrap();
    writeln!(s, " r.v[0] = exp2(a.v[0]);").unwrap();
    writeln!(s, " }}").unwrap();
    writeln!(s, " case 19u: {{").unwrap();
    writeln!(s, " r = jet_exp(a);").unwrap();
    writeln!(s, " r.v[0] = exp(a.v[0]) - 1.0;").unwrap();
    writeln!(s, " }}").unwrap();
    writeln!(s, " case 20u: {{ r = jet_ln(a); }}").unwrap();
    writeln!(s, " case 21u: {{").unwrap();
    writeln!(s, " r = jet_ln(a);").unwrap();
    writeln!(s, " let inv_ln2 = 1.0 / log(2.0);").unwrap();
    writeln!(s, " r.v[0] = log2(a.v[0]);").unwrap();
    for i in 1..k {
        writeln!(s, " r.v[{i}] = r.v[{i}] * inv_ln2;").unwrap();
    }
    writeln!(s, " }}").unwrap();
    writeln!(s, " case 22u: {{").unwrap();
    writeln!(s, " r = jet_ln(a);").unwrap();
    writeln!(s, " let inv_ln10 = 1.0 / log(10.0);").unwrap();
    writeln!(s, " r.v[0] = log(a.v[0]) * inv_ln10;").unwrap();
    for i in 1..k {
        writeln!(s, " r.v[{i}] = r.v[{i}] * inv_ln10;").unwrap();
    }
    writeln!(s, " }}").unwrap();
    writeln!(s, " case 23u: {{").unwrap();
    write!(
        s,
        " var one_plus_a: JetK;\n one_plus_a.v[0] = 1.0 + a.v[0];\n"
    )
    .unwrap();
    for i in 1..k {
        writeln!(s, " one_plus_a.v[{i}] = a.v[{i}];").unwrap();
    }
    writeln!(s, " r = jet_ln(one_plus_a);").unwrap();
    writeln!(s, " r.v[0] = log(1.0 + a.v[0]);").unwrap();
    writeln!(s, " }}").unwrap();

    // Trigonometric / hyperbolic ops delegate to the jet helpers.
    writeln!(s, " case 24u: {{ let sc = jet_sin_cos(a); r = sc.a; }}").unwrap();
    writeln!(s, " case 25u: {{ let sc = jet_sin_cos(a); r = sc.b; }}").unwrap();
    writeln!(s, " case 26u: {{ r = jet_tan(a); }}").unwrap();
    writeln!(s, " case 27u: {{ r = jet_asin(a); }}").unwrap();
    writeln!(s, " case 28u: {{ r = jet_acos(a); }}").unwrap();
    writeln!(s, " case 29u: {{ r = jet_atan(a); }}").unwrap();
    writeln!(s, " case 30u: {{ let sc = jet_sinh_cosh(a); r = sc.a; }}").unwrap();
    writeln!(s, " case 31u: {{ let sc = jet_sinh_cosh(a); r = sc.b; }}").unwrap();
    writeln!(s, " case 32u: {{ r = jet_tanh(a); }}").unwrap();
    writeln!(s, " case 33u: {{ r = jet_asinh(a); }}").unwrap();
    writeln!(s, " case 34u: {{ r = jet_acosh(a); }}").unwrap();
    writeln!(s, " case 35u: {{ r = jet_atanh(a); }}").unwrap();

    // OP_ABS: the sign comes from the first nonzero coefficient, emitted as
    // an unrolled `if … else if … else {}` chain.
    writeln!(s, " case 36u: {{").unwrap();
    writeln!(s, " var sg: f32 = 1.0;").unwrap();
    for i in 0..k {
        writeln!(
            s,
            " if a.v[{i}] != 0.0 {{ sg = select(1.0, -1.0, a.v[{i}] < 0.0); }} else"
        )
        .unwrap();
    }
    writeln!(s, " {{ }}").unwrap();
    writeln!(s, " r.v[0] = abs(a.v[0]);").unwrap();
    for i in 1..k {
        writeln!(s, " r.v[{i}] = sg * a.v[{i}];").unwrap();
    }
    writeln!(s, " }}").unwrap();

    // Piecewise-constant ops: the result is a constant jet.
    writeln!(s, " case 37u, 38u, 39u, 40u, 41u: {{").unwrap();
    writeln!(s, " var val = 0.0f;").unwrap();
    writeln!(s, " switch op {{").unwrap();
    writeln!(
        s,
        " case 37u: {{ val = select(sign(a.v[0]), 1.0, a.v[0] == 0.0); }}"
    )
    .unwrap();
    writeln!(s, " case 38u: {{ val = floor(a.v[0]); }}").unwrap();
    writeln!(s, " case 39u: {{ val = ceil(a.v[0]); }}").unwrap();
    writeln!(s, " case 40u: {{ val = round(a.v[0]); }}").unwrap();
    writeln!(s, " case 41u: {{ val = trunc(a.v[0]); }}").unwrap();
    writeln!(s, " default: {{}}").unwrap();
    writeln!(s, " }}").unwrap();
    writeln!(s, " r = jet_const(val);").unwrap();
    writeln!(s, " }}").unwrap();

    // OP_FRACT: derivatives pass through unchanged.
    writeln!(s, " case 42u: {{").unwrap();
    writeln!(s, " r.v[0] = fract(a.v[0]);").unwrap();
    for i in 1..k {
        writeln!(s, " r.v[{i}] = a.v[{i}];").unwrap();
    }
    writeln!(s, " }}").unwrap();
    writeln!(s, " default: {{}}").unwrap();
    writeln!(s, " }}").unwrap();
    writeln!(s, " jet_store(j_base + i * K, r);").unwrap();
    writeln!(s, " }}").unwrap();

    // --- 4. outputs -----------------------------------------------------
    writeln!(s, "\n // Write output jets").unwrap();
    writeln!(s, " let out_base = bid * n_out * K;").unwrap();
    writeln!(s, " for (var j = 0u; j < n_out; j = j + 1u) {{").unwrap();
    writeln!(s, " let oi = output_indices[j];").unwrap();
    writeln!(s, " let src = j_base + oi * K;").unwrap();
    writeln!(s, " let dst = out_base + j * K;").unwrap();
    for c in 0..k {
        writeln!(s, " jet_outputs[dst + {c}u] = jets[src + {c}u];").unwrap();
    }
    writeln!(s, " }}").unwrap();
    writeln!(s, "}}").unwrap();
}
/// Emits one `#define OP_* <n>` per tape opcode, followed by a blank line.
///
/// `<n>` is the opcode's position in the list, so this order defines the
/// numeric opcode values seen by the CUDA kernel.
fn write_cuda_opcodes(s: &mut String) {
    let names = [
        "OP_INPUT", "OP_CONST", "OP_ADD", "OP_SUB", "OP_MUL", "OP_DIV", "OP_REM",
        "OP_POWF", "OP_ATAN2", "OP_HYPOT", "OP_MAX", "OP_MIN", "OP_NEG", "OP_RECIP",
        "OP_SQRT", "OP_CBRT", "OP_POWI", "OP_EXP", "OP_EXP2", "OP_EXPM1", "OP_LN",
        "OP_LOG2", "OP_LOG10", "OP_LN1P", "OP_SIN", "OP_COS", "OP_TAN", "OP_ASIN",
        "OP_ACOS", "OP_ATAN", "OP_SINH", "OP_COSH", "OP_TANH", "OP_ASINH", "OP_ACOSH",
        "OP_ATANH", "OP_ABS", "OP_SIGNUM", "OP_FLOOR", "OP_CEIL", "OP_ROUND",
        "OP_TRUNC", "OP_FRACT",
    ];
    for (name, code) in names.iter().zip(0usize..) {
        writeln!(s, "#define {name} {code}").unwrap();
    }
    s.push('\n');
}
/// Emits scalar CUDA helpers over the kernel float type `F`:
/// - `_sign`: returns NaN for NaN, otherwise `copysign(1, x)`.
/// - `_cbrt_f`: real cube root.
/// - `_fract`: `x - floor(x)`.
///
/// `_cbrt_f` uses the CUDA math library `cbrt`, which handles negative inputs
/// directly. The previous `pow(|x|, 1/3)` formulation rounds `1/3` first and
/// misses exact results on perfect cubes (e.g. 64 gives 3.9999999999999996 in
/// double).
fn write_cuda_helpers(s: &mut String) {
    writeln!(
        s,
        "__device__ F _sign(F x) {{ return (x != x) ? x : copysign((F)1, x); }}
__device__ F _cbrt_f(F x) {{ return cbrt(x); }}
__device__ F _fract(F x) {{ return x - floor(x); }}
"
    )
    .unwrap();
}
/// Emits the CUDA `JetK` struct (`F v[k]`) and `jet_const`.
///
/// The struct's default constructor zero-fills every coefficient. Because of
/// that, `jet_const` only writes slot 0.
fn write_cuda_jet_type(s: &mut String, k: usize) {
    write!(
        s,
        "struct JetK {{\nF v[{k}];\n\
         __device__ JetK() {{ for(int i=0;i<{k};i++) v[i]=(F)0; }}\n}};\n\n"
    )
    .unwrap();
    s.push_str("__device__ JetK jet_const(F val) {\n JetK j;\n j.v[0] = val;\n return j;\n}\n\n");
}
/// Emits the fully unrolled CUDA jet arithmetic: `jet_add`, `jet_sub`,
/// `jet_neg`, `jet_scale`, `jet_mul`, `jet_div` and `jet_recip`, all over the
/// kernel float type `F`.
///
/// `jet_mul` is the truncated Cauchy product. `jet_div` and `jet_recip` solve
/// that product for the unknown jet one coefficient at a time.
fn write_cuda_jet_arithmetic(s: &mut String, k: usize) {
    /// Cauchy terms `"x.v[j] * y.v[i-j]"` for `j` in `from..=i`, joined by `+`.
    fn conv(x: &str, y: &str, i: usize, from: usize) -> String {
        (from..=i)
            .map(|j| format!("{x}.v[{j}] * {y}.v[{}]", i - j))
            .collect::<Vec<_>>()
            .join(" + ")
    }

    /// Writes one `__device__ JetK <signature> { JetK c; … return c; }`,
    /// with the `prelude` lines before the per-coefficient `lines`.
    fn device_fn(
        s: &mut String,
        signature: &str,
        prelude: &[&str],
        lines: impl Iterator<Item = String>,
    ) {
        writeln!(s, "__device__ JetK {signature} {{").unwrap();
        writeln!(s, " JetK c;").unwrap();
        for stmt in prelude {
            writeln!(s, "{stmt}").unwrap();
        }
        for line in lines {
            writeln!(s, "{line}").unwrap();
        }
        writeln!(s, " return c;\n}}\n").unwrap();
    }

    device_fn(
        s,
        "jet_add(JetK a, JetK b)",
        &[],
        (0..k).map(|i| format!(" c.v[{i}] = a.v[{i}] + b.v[{i}];")),
    );
    device_fn(
        s,
        "jet_sub(JetK a, JetK b)",
        &[],
        (0..k).map(|i| format!(" c.v[{i}] = a.v[{i}] - b.v[{i}];")),
    );
    device_fn(
        s,
        "jet_neg(JetK a)",
        &[],
        (0..k).map(|i| format!(" c.v[{i}] = -a.v[{i}];")),
    );
    device_fn(
        s,
        "jet_scale(JetK a, F s)",
        &[],
        (0..k).map(|i| format!(" c.v[{i}] = a.v[{i}] * s;")),
    );
    device_fn(
        s,
        "jet_mul(JetK a, JetK b)",
        &[],
        (0..k).map(|i| format!(" c.v[{i}] = {};", conv("a", "b", i, 0))),
    );
    device_fn(
        s,
        "jet_div(JetK a, JetK b)",
        &[" F inv_b0 = (F)1 / b.v[0];", " c.v[0] = a.v[0] * inv_b0;"],
        (1..k).map(|i| format!(" c.v[{i}] = (a.v[{i}] - ({})) * inv_b0;", conv("b", "c", i, 1))),
    );
    device_fn(
        s,
        "jet_recip(JetK a)",
        &[" c.v[0] = (F)1 / a.v[0];"],
        (1..k).map(|i| format!(" c.v[{i}] = -({}) * c.v[0];", conv("a", "c", i, 1))),
    );
}
/// Emits CUDA device functions for the elementary transcendental jets.
///
/// The functions are `jet_exp`, `jet_ln`, `jet_sqrt`, `jet_sin_cos`,
/// `jet_sinh_cosh` (both returning a `JetPair`, whose struct is emitted here),
/// `jet_tan` and `jet_tanh`. Each uses the standard Taylor-coefficient
/// recurrence. For example, for `c = exp(a)`:
/// `c_i = (1/i) * sum_{j=1..i} j * a_j * c_{i-j}`.
///
/// The `1/i` factors are emitted with full f64 precision. A fixed
/// 10-decimal rendering (`0.3333333333`) would inject a ~1e-10 relative
/// error into every order >= 3 coefficient of a double-precision kernel.
fn write_cuda_jet_transcendental(s: &mut String, k: usize) {
    // `(F)<1/i>`. `{:?}` prints the shortest decimal that round-trips the f64,
    // e.g. `(F)0.3333333333333333`.
    fn inv(i: usize) -> String {
        format!("(F){:?}", 1.0 / i as f64)
    }

    // `(F)1.0 * a.v[1] * g.v[i-1] + ... + (F)i.0 * a.v[i] * g.v[0]`: the
    // weighted convolution from the recurrence of any `c' = a' * g`.
    fn weighted(g: &str, i: usize) -> String {
        (1..=i)
            .map(|j| format!("(F){:.1} * a.v[{j}] * {g}.v[{}]", j as f64, i - j))
            .collect::<Vec<_>>()
            .join(" + ")
    }

    // `x.v[lo] * x.v[i-lo] + ... + x.v[hi] * x.v[i-hi]`.
    fn self_conv(x: &str, lo: usize, hi: usize, i: usize) -> String {
        (lo..=hi)
            .map(|j| format!("{x}.v[{j}] * {x}.v[{}]", i - j))
            .collect::<Vec<_>>()
            .join(" + ")
    }

    // Coupled pair (f, g) with f' = a' * g and g' = sign * a' * f
    // (sin/cos with sign "-", sinh/cosh with sign "").
    fn write_pair(
        s: &mut String,
        k: usize,
        name: &str,
        (first, second): (&str, &str),
        (f_first, f_second): (&str, &str),
        second_sign: &str,
    ) {
        writeln!(s, "__device__ JetPair {name}(JetK a) {{").unwrap();
        writeln!(s, " JetK {first}, {second};").unwrap();
        writeln!(s, " {first}.v[0] = {f_first}(a.v[0]);").unwrap();
        writeln!(s, " {second}.v[0] = {f_second}(a.v[0]);").unwrap();
        for i in 1..k {
            writeln!(s, " {first}.v[{i}] = {} * ({});", inv(i), weighted(second, i)).unwrap();
            writeln!(
                s,
                " {second}.v[{i}] = {second_sign}{} * ({});",
                inv(i),
                weighted(first, i)
            )
            .unwrap();
        }
        writeln!(s, " JetPair p; p.a = {first}; p.b = {second};").unwrap();
        writeln!(s, " return p;\n}}\n").unwrap();
    }

    // tan / tanh: c' = a' * sc, with sc = 1 + c^2 (tan) or 1 - c^2 (tanh).
    fn write_tan_like(s: &mut String, k: usize, name: &str, plus: bool) {
        let op = if plus { '+' } else { '-' };
        writeln!(s, "__device__ JetK jet_{name}(JetK a) {{").unwrap();
        writeln!(s, " JetK c, sc;").unwrap();
        writeln!(s, " c.v[0] = {name}(a.v[0]);").unwrap();
        writeln!(s, " sc.v[0] = (F)1 {op} c.v[0] * c.v[0];").unwrap();
        for i in 1..k {
            writeln!(s, " c.v[{i}] = {} * ({});", inv(i), weighted("sc", i)).unwrap();
            let sq = self_conv("c", 0, i, i);
            if plus {
                writeln!(s, " sc.v[{i}] = {sq};").unwrap();
            } else {
                writeln!(s, " sc.v[{i}] = -({sq});").unwrap();
            }
        }
        writeln!(s, " return c;\n}}\n").unwrap();
    }

    // exp: c' = a' * c.
    writeln!(s, "__device__ JetK jet_exp(JetK a) {{").unwrap();
    writeln!(s, " JetK c;").unwrap();
    writeln!(s, " c.v[0] = exp(a.v[0]);").unwrap();
    for i in 1..k {
        writeln!(s, " c.v[{i}] = {} * ({});", inv(i), weighted("c", i)).unwrap();
    }
    writeln!(s, " return c;\n}}\n").unwrap();

    // ln: a * c' = a'. `inv_a0` is only declared when some order >= 1 uses it,
    // matching how `jet_sqrt` treats `inv_2c0`.
    writeln!(s, "__device__ JetK jet_ln(JetK a) {{").unwrap();
    writeln!(s, " JetK c;").unwrap();
    if k > 1 {
        writeln!(s, " F inv_a0 = (F)1 / a.v[0];").unwrap();
    }
    writeln!(s, " c.v[0] = log(a.v[0]);").unwrap();
    if k > 1 {
        writeln!(s, " c.v[1] = a.v[1] * inv_a0;").unwrap();
    }
    for i in 2..k {
        let terms = (1..i)
            .map(|j| format!("(F){:.1} * c.v[{j}] * a.v[{}]", j as f64, i - j))
            .collect::<Vec<_>>()
            .join(" + ");
        writeln!(
            s,
            " c.v[{i}] = (a.v[{i}] - {} * ({terms})) * inv_a0;",
            inv(i)
        )
        .unwrap();
    }
    writeln!(s, " return c;\n}}\n").unwrap();

    // sqrt: c * c = a.
    writeln!(s, "__device__ JetK jet_sqrt(JetK a) {{").unwrap();
    writeln!(s, " JetK c;").unwrap();
    writeln!(s, " c.v[0] = sqrt(a.v[0]);").unwrap();
    if k > 1 {
        writeln!(s, " F inv_2c0 = (F)0.5 / c.v[0];").unwrap();
        writeln!(s, " c.v[1] = a.v[1] * inv_2c0;").unwrap();
    }
    for i in 2..k {
        writeln!(
            s,
            " c.v[{i}] = (a.v[{i}] - ({})) * inv_2c0;",
            self_conv("c", 1, i - 1, i)
        )
        .unwrap();
    }
    writeln!(s, " return c;\n}}\n").unwrap();

    writeln!(s, "struct JetPair {{ JetK a; JetK b; }};\n").unwrap();
    write_pair(s, k, "jet_sin_cos", ("sn", "co"), ("sin", "cos"), "-");
    write_pair(s, k, "jet_sinh_cosh", ("sh", "ch"), ("sinh", "cosh"), "");
    write_tan_like(s, k, "tan", true);
    write_tan_like(s, k, "tanh", false);
}
/// Emits CUDA jets for `atan`, `asin`, `acos`, `asinh`, `acosh` and `atanh`.
///
/// Each function `f` has derivative `f'(a) = ±1 / d(a)` or `±1 / sqrt(d(a))`,
/// where `d` is `1 + a^2`, `1 - a^2` or `a^2 - 1`. The generated code:
/// 1. builds the jet `d` from `asq = a * a`,
/// 2. forms `g = 1/d` or `1/sqrt(d)`,
/// 3. integrates `c' = a' * g` via
///    `c_i = ±(1/i) * sum_{j=1..i} j * a_j * g_{i-j}`.
///
/// `acosh` builds `d_0` as `(a0 - 1) * (a0 + 1)` rather than `a0^2 - 1` to
/// limit cancellation near `a0 = 1`. The `1/i` literals are emitted with full
/// f64 precision; a 10-decimal rendering would truncate 1/3.
fn write_cuda_jet_inverse_trig(s: &mut String, k: usize) {
    // Shape of the derivative denominator `d`.
    enum Denom {
        OnePlusSq,
        OneMinusSq,
        SqMinusOne,
    }

    // (name, denominator, take 1/sqrt(d) instead of 1/d, negate result)
    let fns = [
        ("atan", Denom::OnePlusSq, false, false),
        ("asin", Denom::OneMinusSq, true, false),
        ("acos", Denom::OneMinusSq, true, true),
        ("asinh", Denom::OnePlusSq, true, false),
        ("acosh", Denom::SqMinusOne, true, false),
        ("atanh", Denom::OneMinusSq, false, false),
    ];

    for (name, denom, use_sqrt, negate) in fns {
        writeln!(s, "__device__ JetK jet_{name}(JetK a) {{").unwrap();
        writeln!(s, " JetK asq = jet_mul(a, a);").unwrap();
        writeln!(s, " JetK d;").unwrap();
        let (d0, tail_sign) = match denom {
            Denom::OnePlusSq => ("(F)1 + asq.v[0]", ""),
            Denom::OneMinusSq => ("(F)1 - asq.v[0]", "-"),
            Denom::SqMinusOne => ("(a.v[0] - (F)1) * (a.v[0] + (F)1)", ""),
        };
        writeln!(s, " d.v[0] = {d0};").unwrap();
        for i in 1..k {
            writeln!(s, " d.v[{i}] = {tail_sign}asq.v[{i}];").unwrap();
        }
        let g = if use_sqrt {
            "jet_recip(jet_sqrt(d))"
        } else {
            "jet_recip(d)"
        };
        writeln!(s, " JetK g = {g};").unwrap();
        writeln!(s, " JetK c;").unwrap();
        writeln!(s, " c.v[0] = {name}(a.v[0]);").unwrap();
        let sign = if negate { "-" } else { "" };
        for i in 1..k {
            let terms = (1..=i)
                .map(|j| format!("(F){:.1} * a.v[{j}] * g.v[{}]", j as f64, i - j))
                .collect::<Vec<_>>()
                .join(" + ");
            // `{:?}` = shortest round-trip f64 text, e.g. 0.3333333333333333.
            writeln!(s, " c.v[{i}] = {sign}(F){:?} * ({terms});", 1.0 / i as f64).unwrap();
        }
        writeln!(s, " return c;\n}}\n").unwrap();
    }
}
/// Emits the `taylor_forward_kth` CUDA kernel for Taylor order `k`.
///
/// One thread handles one batch element `bid`. Its `num_variables * K` jet
/// coefficients are stored contiguously in `jets`, starting at
/// `bid * num_variables * K`. The kernel:
/// 1. seeds every variable slot with `constants[i]`, with higher orders zeroed;
/// 2. overwrites the first `num_inputs` slots with the primal inputs and, when
///    `k > 1`, the first-order direction seeds;
/// 3. replays the tape from `num_inputs` to `num_ops`, dispatching on the
///    opcode and storing the result jet back into slot `i`:
///    - operand slots come from `arg0` / `arg1`;
///    - for `OP_POWI`, `arg1` holds the integer exponent;
///    - `OP_CONST` slots keep their seeded value;
/// 4. copies the jets named by `output_indices` into `jet_outputs`.
///
/// Every slot offset is computed in `unsigned long long`, so slot-index × `K`
/// products cannot wrap at 32 bits. `expm1`, `log1p` and `log10` are used for
/// the primal values of the corresponding ops to avoid cancellation near
/// zero / one.
fn write_cuda_main_kernel(s: &mut String, k: usize) {
    // Loads the second operand jet `b` (slot `arg1[i]`) inside a binary-op case.
    fn load_b(s: &mut String, k: usize) {
        writeln!(
            s,
            " JetK b; unsigned long long b_off = j_base + (unsigned long long)b_idx * K;"
        )
        .unwrap();
        for c in 0..k {
            writeln!(s, " b.v[{c}] = jets[b_off + {c}];").unwrap();
        }
    }

    // OP_HYPOT (case 9): inf/NaN guards, then a scaled hypot.
    // When both primal values are zero, the series are shifted down one order,
    // hypot is taken on the shifted series, and the result is shifted back up.
    fn write_hypot_case(s: &mut String, k: usize) {
        writeln!(s, " case 9: {{").unwrap();
        load_b(s, k);
        writeln!(s, " F aa = fabs(a.v[0]);").unwrap();
        writeln!(s, " F bb = fabs(b.v[0]);").unwrap();
        writeln!(s, " if (isinf(aa) || isinf(bb)) {{").unwrap();
        writeln!(s, " F inf_sentinel = fmax(aa, bb);").unwrap();
        writeln!(s, " F nan_sentinel = inf_sentinel - inf_sentinel;").unwrap();
        writeln!(s, " r.v[0] = inf_sentinel;").unwrap();
        for i in 1..k {
            writeln!(s, " r.v[{i}] = nan_sentinel;").unwrap();
        }
        writeln!(s, " }} else if (isnan(aa) || isnan(bb)) {{").unwrap();
        writeln!(s, " r.v[0] = a.v[0] + b.v[0];").unwrap();
        for i in 1..k {
            writeln!(s, " r.v[{i}] = (F)0;").unwrap();
        }
        writeln!(s, " }} else {{").unwrap();
        writeln!(s, " F h = fmax(aa, bb);").unwrap();
        writeln!(s, " if (h == (F)0) {{").unwrap();
        if k >= 2 {
            let last = k - 1;
            writeln!(s, " JetK a_shifted;").unwrap();
            writeln!(s, " JetK b_shifted;").unwrap();
            for i in 0..last {
                writeln!(s, " a_shifted.v[{i}] = a.v[{}];", i + 1).unwrap();
                writeln!(s, " b_shifted.v[{i}] = b.v[{}];", i + 1).unwrap();
            }
            writeln!(s, " a_shifted.v[{last}] = (F)0;").unwrap();
            writeln!(s, " b_shifted.v[{last}] = (F)0;").unwrap();
            writeln!(
                s,
                " F h_inner = fmax(fabs(a_shifted.v[0]), fabs(b_shifted.v[0]));"
            )
            .unwrap();
            writeln!(s, " if (h_inner == (F)0) {{").unwrap();
            writeln!(s, " F pos_inf = (F)1 / (F)0;").unwrap();
            writeln!(s, " r.v[0] = (F)0;").unwrap();
            for i in 1..k {
                writeln!(s, " r.v[{i}] = pos_inf;").unwrap();
            }
            writeln!(s, " }} else {{").unwrap();
            writeln!(s, " F inv_h_inner = (F)1 / h_inner;").unwrap();
            writeln!(s, " JetK a_ss = jet_scale(a_shifted, inv_h_inner);").unwrap();
            writeln!(s, " JetK b_ss = jet_scale(b_shifted, inv_h_inner);").unwrap();
            writeln!(
                s,
                " JetK sum_sq_s = jet_add(jet_mul(a_ss, a_ss), jet_mul(b_ss, b_ss));"
            )
            .unwrap();
            writeln!(s, " JetK r_s_inner = jet_sqrt(sum_sq_s);").unwrap();
            writeln!(s, " JetK inner = jet_scale(r_s_inner, h_inner);").unwrap();
            writeln!(s, " r.v[0] = (F)0;").unwrap();
            for i in 1..k {
                writeln!(s, " r.v[{i}] = inner.v[{}];", i - 1).unwrap();
            }
            writeln!(s, " }}").unwrap();
        } else {
            writeln!(s, " r.v[0] = (F)0;").unwrap();
        }
        writeln!(s, " }} else {{").unwrap();
        writeln!(s, " F inv_h = (F)1 / h;").unwrap();
        writeln!(s, " JetK a_s = jet_scale(a, inv_h);").unwrap();
        writeln!(s, " JetK b_s = jet_scale(b, inv_h);").unwrap();
        writeln!(
            s,
            " JetK sum_sq = jet_add(jet_mul(a_s, a_s), jet_mul(b_s, b_s));"
        )
        .unwrap();
        writeln!(s, " JetK r_s = jet_sqrt(sum_sq);").unwrap();
        writeln!(s, " r = jet_scale(r_s, h);").unwrap();
        writeln!(s, " r.v[0] = hypot(a.v[0], b.v[0]);").unwrap();
        writeln!(s, " }}").unwrap();
        writeln!(s, " }}").unwrap();
        writeln!(s, " break;").unwrap();
        writeln!(s, " }}").unwrap();
    }

    // OP_POWI (case 16): integer power; the exponent is `(int)arg1[i]`.
    fn write_powi_case(s: &mut String, k: usize) {
        writeln!(s, " case 16: {{").unwrap();
        writeln!(s, " int ni = (int)b_idx;").unwrap();
        writeln!(s, " F n = (F)ni;").unwrap();
        writeln!(s, " if (ni == 0) {{ r = jet_const((F)1); }}").unwrap();
        writeln!(s, " else if (ni == 1) {{ r = a; }}").unwrap();
        // Zero base: exact products. The exp(n*ln|a|) path below would evaluate
        // ln(0) and turn every higher-order coefficient into NaN.
        let zero_base_chains = [
            (2, "r = jet_mul(a, a);"),
            (3, "r = jet_mul(jet_mul(a, a), a);"),
            (4, "JetK a2 = jet_mul(a, a); r = jet_mul(a2, a2);"),
            (
                5,
                "JetK a2 = jet_mul(a, a); JetK a4 = jet_mul(a2, a2); r = jet_mul(a4, a);",
            ),
            (
                6,
                "JetK a2 = jet_mul(a, a); JetK a4 = jet_mul(a2, a2); r = jet_mul(a4, a2);",
            ),
            (
                7,
                "JetK a2 = jet_mul(a, a); JetK a4 = jet_mul(a2, a2); r = jet_mul(jet_mul(a4, a2), a);",
            ),
            (
                8,
                "JetK a2 = jet_mul(a, a); JetK a4 = jet_mul(a2, a2); r = jet_mul(a4, a4);",
            ),
        ];
        for (exponent, body) in zero_base_chains {
            writeln!(
                s,
                " else if (a.v[0] == F(0) && ni == {exponent}) {{ {body} }}"
            )
            .unwrap();
        }
        // Zero base, larger exponents: binary exponentiation keeps the same
        // exact-product semantics. Previously these fell through to the ln path
        // and produced NaN derivatives.
        writeln!(
            s,
            " else if (a.v[0] == F(0) && ni > 8) {{ JetK acc = jet_const((F)1); JetK base = a; for (int e = ni; e > 0; e >>= 1) {{ if (e & 1) acc = jet_mul(acc, base); base = jet_mul(base, base); }} r = acc; }}"
        )
        .unwrap();
        // Non-positive base: (-1)^n * |a|^n with |a| continued from the sign of a0.
        writeln!(s, " else if (a.v[0] <= F(0)) {{").unwrap();
        writeln!(s, " F sf = (ni % 2 == 0) ? F(1) : F(-1);").unwrap();
        writeln!(s, " F sg = _sign(a.v[0]);").unwrap();
        writeln!(s, " JetK abs_a; abs_a.v[0] = fabs(a.v[0]);").unwrap();
        for i in 1..k {
            writeln!(s, " abs_a.v[{i}] = sg * a.v[{i}];").unwrap();
        }
        writeln!(s, " JetK e = jet_exp(jet_scale(jet_ln(abs_a), n));").unwrap();
        for i in 0..k {
            writeln!(s, " r.v[{i}] = sf * e.v[{i}];").unwrap();
        }
        writeln!(s, " r.v[0] = pow(a.v[0], n);").unwrap();
        writeln!(s, " }}").unwrap();
        writeln!(
            s,
            " else {{ r = jet_exp(jet_scale(jet_ln(a), n)); r.v[0] = pow(a.v[0], n); }}"
        )
        .unwrap();
        writeln!(s, " break;").unwrap();
        writeln!(s, " }}").unwrap();
    }

    // ---- signature and thread setup ----------------------------------------
    writeln!(s, "extern \"C\" __global__ void taylor_forward_kth(").unwrap();
    for param in [
        " const unsigned int* __restrict__ opcodes,",
        " const unsigned int* __restrict__ arg0,",
        " const unsigned int* __restrict__ arg1,",
        " const F* __restrict__ constants,",
        " const F* __restrict__ primal_inputs,",
        " const F* __restrict__ direction_seeds,",
        " F* __restrict__ jets,",
        " F* __restrict__ jet_outputs,",
        " const unsigned int* __restrict__ output_indices,",
        " unsigned int num_ops,",
        " unsigned int num_inputs,",
        " unsigned int num_variables,",
        " unsigned int num_outputs,",
        " unsigned int batch_size",
    ] {
        writeln!(s, "{param}").unwrap();
    }
    writeln!(s, ") {{").unwrap();
    writeln!(s, " unsigned int bid = blockIdx.x * blockDim.x + threadIdx.x;").unwrap();
    writeln!(s, " if (bid >= batch_size) return;").unwrap();
    writeln!(s, " const unsigned int K = {k};").unwrap();
    writeln!(
        s,
        " unsigned long long j_base = (unsigned long long)bid * num_variables * K;"
    )
    .unwrap();

    // ---- seed constants, then inputs ----------------------------------------
    writeln!(s, " for (unsigned int i = 0; i < num_variables; i++) {{").unwrap();
    // Widen before multiplying: a 32-bit `i * K` could wrap on large tapes.
    writeln!(s, " unsigned long long off = j_base + (unsigned long long)i * K;").unwrap();
    writeln!(s, " jets[off] = constants[i];").unwrap();
    for c in 1..k {
        writeln!(s, " jets[off + {c}] = (F)0;").unwrap();
    }
    writeln!(s, " }}").unwrap();
    writeln!(
        s,
        " unsigned long long in_base = (unsigned long long)bid * num_inputs;"
    )
    .unwrap();
    writeln!(s, " for (unsigned int i = 0; i < num_inputs; i++) {{").unwrap();
    writeln!(s, " unsigned long long off = j_base + (unsigned long long)i * K;").unwrap();
    writeln!(s, " jets[off] = primal_inputs[in_base + i];").unwrap();
    if k > 1 {
        writeln!(s, " jets[off + 1] = direction_seeds[in_base + i];").unwrap();
    }
    writeln!(s, " }}").unwrap();

    // ---- tape replay ----------------------------------------------------------
    writeln!(s, " for (unsigned int i = num_inputs; i < num_ops; i++) {{").unwrap();
    writeln!(s, " unsigned int op = opcodes[i];").unwrap();
    writeln!(s, " if (op == OP_CONST) continue;").unwrap();
    writeln!(s, " unsigned int a_idx = arg0[i];").unwrap();
    writeln!(s, " unsigned int b_idx = arg1[i];").unwrap();
    writeln!(s, " JetK a;").unwrap();
    writeln!(
        s,
        " unsigned long long a_off = j_base + (unsigned long long)a_idx * K;"
    )
    .unwrap();
    for c in 0..k {
        writeln!(s, " a.v[{c}] = jets[a_off + {c}];").unwrap();
    }
    writeln!(s, " JetK r;").unwrap();
    writeln!(s, " switch (op) {{").unwrap();

    // OP_ADD..OP_DIV
    for (case, func) in [(2, "jet_add"), (3, "jet_sub"), (4, "jet_mul"), (5, "jet_div")] {
        writeln!(s, " case {case}: {{").unwrap();
        load_b(s, k);
        writeln!(s, " r = {func}(a, b); break;").unwrap();
        writeln!(s, " }}").unwrap();
    }

    // OP_REM: a - trunc(a0/b0) * b, with the quotient treated as locally constant.
    writeln!(s, " case 6: {{").unwrap();
    load_b(s, k);
    writeln!(s, " F q = trunc(a.v[0] / b.v[0]);").unwrap();
    writeln!(s, " r.v[0] = fmod(a.v[0], b.v[0]);").unwrap();
    for i in 1..k {
        writeln!(s, " r.v[{i}] = a.v[{i}] - q * b.v[{i}];").unwrap();
    }
    writeln!(s, " break;").unwrap();
    writeln!(s, " }}").unwrap();

    // OP_POWF: exp(b * ln a) for a0 > 0. For a0 <= 0 only the first order is
    // formed (with db = 0) and orders >= 2 are set to NaN.
    writeln!(s, " case 7: {{").unwrap();
    load_b(s, k);
    writeln!(s, " if (a.v[0] <= F(0)) {{").unwrap();
    writeln!(s, " F val = pow(a.v[0], b.v[0]);").unwrap();
    writeln!(s, " F da = b.v[0] * pow(a.v[0], b.v[0] - F(1));").unwrap();
    writeln!(s, " F db = F(0);").unwrap();
    writeln!(s, " r.v[0] = val;").unwrap();
    if k > 1 {
        writeln!(s, " r.v[1] = da * a.v[1] + db * b.v[1];").unwrap();
    }
    for i in 2..k {
        writeln!(s, " r.v[{i}] = F(0.0/0.0);").unwrap();
    }
    writeln!(s, " }} else {{").unwrap();
    writeln!(s, " JetK lna = jet_ln(a);").unwrap();
    writeln!(s, " JetK product = jet_mul(b, lna);").unwrap();
    writeln!(s, " r = jet_exp(product);").unwrap();
    writeln!(s, " r.v[0] = pow(a.v[0], b.v[0]);").unwrap();
    writeln!(s, " }}").unwrap();
    writeln!(s, " break;").unwrap();
    writeln!(s, " }}").unwrap();

    // OP_ATAN2: atan(a/b), or -atan(b/a) for the higher orders when b0 == 0.
    writeln!(s, " case 8: {{").unwrap();
    load_b(s, k);
    writeln!(s, " if (b.v[0] == F(0)) {{").unwrap();
    writeln!(s, " r.v[0] = atan2(a.v[0], b.v[0]);").unwrap();
    writeln!(s, " if (a.v[0] == F(0)) {{").unwrap();
    for i in 1..k {
        writeln!(s, " r.v[{i}] = F(0);").unwrap();
    }
    writeln!(s, " }} else {{").unwrap();
    writeln!(s, " JetK ratio = jet_div(b, a);").unwrap();
    writeln!(s, " JetK c = jet_atan(ratio);").unwrap();
    for i in 1..k {
        writeln!(s, " r.v[{i}] = -c.v[{i}];").unwrap();
    }
    writeln!(s, " }}").unwrap();
    writeln!(s, " }} else {{").unwrap();
    writeln!(s, " r = jet_atan(jet_div(a, b));").unwrap();
    writeln!(s, " r.v[0] = atan2(a.v[0], b.v[0]);").unwrap();
    writeln!(s, " }}").unwrap();
    writeln!(s, " break;").unwrap();
    writeln!(s, " }}").unwrap();

    write_hypot_case(s, k);

    // OP_MAX / OP_MIN: select the whole jet by comparing primal values.
    for (case, cmp) in [(10, ">="), (11, "<=")] {
        writeln!(s, " case {case}: {{").unwrap();
        load_b(s, k);
        writeln!(s, " r = (a.v[0] {cmp} b.v[0]) ? a : b; break;").unwrap();
        writeln!(s, " }}").unwrap();
    }

    for line in [
        " case 12: r = jet_neg(a); break;",
        " case 13: r = jet_recip(a); break;",
        " case 14: r = jet_sqrt(a); break;",
    ] {
        writeln!(s, "{line}").unwrap();
    }

    // OP_CBRT: sign(a) * exp(ln|a| / 3); infinite derivatives at a0 == 0.
    writeln!(s, " case 15: {{").unwrap();
    writeln!(s, " if (a.v[0] == F(0)) {{").unwrap();
    writeln!(s, " r.v[0] = F(0);").unwrap();
    for i in 1..k {
        writeln!(s, " r.v[{i}] = F(1.0/0.0);").unwrap();
    }
    writeln!(s, " }} else {{").unwrap();
    writeln!(s, " F sg = _sign(a.v[0]);").unwrap();
    writeln!(s, " JetK abs_a; abs_a.v[0] = fabs(a.v[0]);").unwrap();
    for i in 1..k {
        writeln!(s, " abs_a.v[{i}] = sg * a.v[{i}];").unwrap();
    }
    writeln!(
        s,
        " JetK e = jet_exp(jet_scale(jet_ln(abs_a), (F)(1.0/3.0)));"
    )
    .unwrap();
    for i in 0..k {
        writeln!(s, " r.v[{i}] = sg * e.v[{i}];").unwrap();
    }
    writeln!(s, " }}").unwrap();
    writeln!(s, " break;").unwrap();
    writeln!(s, " }}").unwrap();

    write_powi_case(s, k);

    // Exponentials. EXPM1 uses expm1() so the primal keeps precision near 0.
    for line in [
        " case 17: r = jet_exp(a); break;",
        " case 18: { r = jet_exp(jet_scale(a, log((F)2))); r.v[0] = exp2(a.v[0]); break; }",
        " case 19: { r = jet_exp(a); r.v[0] = expm1(a.v[0]); break; }",
        " case 20: r = jet_ln(a); break;",
    ] {
        writeln!(s, "{line}").unwrap();
    }

    // OP_LOG2 / OP_LOG10: scaled ln jets with a dedicated primal routine.
    for (case, inv_name, base, primal) in [
        (21, "inv_ln2", "2", "log2"),
        (22, "inv_ln10", "10", "log10"),
    ] {
        writeln!(s, " case {case}: {{").unwrap();
        writeln!(s, " r = jet_ln(a);").unwrap();
        writeln!(s, " F {inv_name} = (F)1 / log((F){base});").unwrap();
        writeln!(s, " r.v[0] = {primal}(a.v[0]);").unwrap();
        for i in 1..k {
            writeln!(s, " r.v[{i}] *= {inv_name};").unwrap();
        }
        writeln!(s, " break;").unwrap();
        writeln!(s, " }}").unwrap();
    }

    // OP_LN1P: ln(1 + a); log1p() keeps the primal accurate for small a0.
    writeln!(s, " case 23: {{").unwrap();
    writeln!(s, " JetK opa; opa.v[0] = (F)1 + a.v[0];").unwrap();
    for i in 1..k {
        writeln!(s, " opa.v[{i}] = a.v[{i}];").unwrap();
    }
    writeln!(s, " r = jet_ln(opa); r.v[0] = log1p(a.v[0]); break;").unwrap();
    writeln!(s, " }}").unwrap();

    for line in [
        " case 24: { JetPair sc = jet_sin_cos(a); r = sc.a; break; }",
        " case 25: { JetPair sc = jet_sin_cos(a); r = sc.b; break; }",
        " case 26: r = jet_tan(a); break;",
        " case 27: r = jet_asin(a); break;",
        " case 28: r = jet_acos(a); break;",
        " case 29: r = jet_atan(a); break;",
        " case 30: { JetPair sc = jet_sinh_cosh(a); r = sc.a; break; }",
        " case 31: { JetPair sc = jet_sinh_cosh(a); r = sc.b; break; }",
        " case 32: r = jet_tanh(a); break;",
        " case 33: r = jet_asinh(a); break;",
        " case 34: r = jet_acosh(a); break;",
        " case 35: r = jet_atanh(a); break;",
    ] {
        writeln!(s, "{line}").unwrap();
    }

    // OP_ABS: the sign comes from the first non-zero coefficient, so the
    // higher orders of |a| are correct when a0 == 0.
    writeln!(s, " case 36: {{").unwrap();
    writeln!(s, " F sg = F(1);").unwrap();
    for i in 0..k {
        writeln!(
            s,
            " if (a.v[{i}] != F(0)) {{ sg = (a.v[{i}] < F(0)) ? F(-1) : F(1); }} else"
        )
        .unwrap();
    }
    writeln!(s, " {{ /* identically zero series → sg = 1 */ }}").unwrap();
    writeln!(s, " r.v[0] = fabs(a.v[0]);").unwrap();
    for i in 1..k {
        writeln!(s, " r.v[{i}] = sg * a.v[{i}];").unwrap();
    }
    writeln!(s, " break;").unwrap();
    writeln!(s, " }}").unwrap();

    // Piecewise-constant ops yield constant jets.
    for line in [
        " case 37: r = jet_const(_sign(a.v[0])); break;",
        " case 38: r = jet_const(floor(a.v[0])); break;",
        " case 39: r = jet_const(ceil(a.v[0])); break;",
        " case 40: r = jet_const(round(a.v[0])); break;",
        " case 41: r = jet_const(trunc(a.v[0])); break;",
    ] {
        writeln!(s, "{line}").unwrap();
    }

    // OP_FRACT: x - floor(x); higher orders pass through unchanged.
    writeln!(s, " case 42: {{").unwrap();
    writeln!(s, " r.v[0] = _fract(a.v[0]);").unwrap();
    for i in 1..k {
        writeln!(s, " r.v[{i}] = a.v[{i}];").unwrap();
    }
    writeln!(s, " break;").unwrap();
    writeln!(s, " }}").unwrap();
    writeln!(s, " default: break;").unwrap();
    writeln!(s, " }}").unwrap();

    writeln!(s, " unsigned long long r_off = j_base + (unsigned long long)i * K;").unwrap();
    for c in 0..k {
        writeln!(s, " jets[r_off + {c}] = r.v[{c}];").unwrap();
    }
    writeln!(s, " }}").unwrap();

    // ---- gather outputs -------------------------------------------------------
    writeln!(
        s,
        " unsigned long long out_base = (unsigned long long)bid * num_outputs * K;"
    )
    .unwrap();
    writeln!(s, " for (unsigned int j = 0; j < num_outputs; j++) {{").unwrap();
    writeln!(s, " unsigned int oi = output_indices[j];").unwrap();
    writeln!(
        s,
        " unsigned long long src = j_base + (unsigned long long)oi * K;"
    )
    .unwrap();
    writeln!(s, " unsigned long long dst = out_base + (unsigned long long)j * K;").unwrap();
    for c in 0..k {
        writeln!(s, " jet_outputs[dst + {c}] = jets[src + {c}];").unwrap();
    }
    writeln!(s, " }}").unwrap();
    writeln!(s, "}}").unwrap();
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The K=3 WGSL shader declares the jet struct and the core entry points.
    #[test]
    fn wgsl_k3_compiles() {
        let shader = generate_taylor_wgsl(3);
        for needle in [
            "struct JetK { v: array<f32, 3>, }",
            "fn jet_mul",
            "fn jet_exp",
            "fn main",
        ] {
            assert!(
                shader.contains(needle),
                "WGSL K=3 output is missing `{}`",
                needle
            );
        }
    }

    /// The K=3 CUDA source declares the coefficient array and the kernel symbol.
    #[test]
    fn cuda_k3_compiles() {
        let kernel = generate_taylor_cuda(3);
        for needle in ["F v[3]", "jet_mul", "jet_exp", "taylor_forward_kth"] {
            assert!(
                kernel.contains(needle),
                "CUDA K=3 output is missing `{}`",
                needle
            );
        }
    }

    /// Every supported order yields correctly sized jet storage in both backends.
    #[test]
    fn all_k_values_generate() {
        (1..=5).for_each(|k| {
            let wgsl_decl = format!("array<f32, {k}>");
            assert!(generate_taylor_wgsl(k).contains(&wgsl_decl));
            let cuda_decl = format!("F v[{k}]");
            assert!(generate_taylor_cuda(k).contains(&cuda_decl));
        });
    }
}