#version 430
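// Runge-Kutta 4th order - Final combination compute shader
// Computes y_new = y + (k1 + 2*k2 + 2*k3 + k4) / 6
// No step size h is applied in this shader, so for this to be the standard
// RK4 update each k buffer must already hold h * f(t_i, y_i).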
// 1-D workgroups of 256 invocations. Presumably dispatched with
// ceil(n / 256) workgroups in x - NOTE(review): confirm on the host side.
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;

// Storage layout shared by every buffer below: invocation `index` owns the
// four consecutive floats [4*index, 4*index + 3].

// Current state y (input only).
layout(std430, binding = 0) readonly buffer YBuffer {
    float y[];
};
// RK4 stage 1 increment (input only).
layout(std430, binding = 1) readonly buffer K1Buffer {
    float k1[];
};
// RK4 stage 2 increment (input only).
layout(std430, binding = 2) readonly buffer K2Buffer {
    float k2[];
};
// RK4 stage 3 increment (input only).
layout(std430, binding = 3) readonly buffer K3Buffer {
    float k3[];
};
// RK4 stage 4 increment (input only).
layout(std430, binding = 4) readonly buffer K4Buffer {
    float k4[];
};
// Output: the combined next state, written by main().
layout(std430, binding = 5) buffer ResultBuffer {
    float result[];
};

// Invocations with gl_GlobalInvocationID.x >= n do nothing, and each active
// invocation handles 4 floats, so n is effectively the number of 4-float
// groups. NOTE(review): confirm the host passes that count, not a float count.
uniform int n;
// Combines the four RK4 stage increments into the next state:
//     result = y + (k1 + 2*k2 + 2*k3 + k4) / 6
// Each invocation processes one group of 4 consecutive floats, i.e. elements
// [4*index, 4*index + 3] of every buffer, where index = gl_GlobalInvocationID.x.
void main() {
    uint index = gl_GlobalInvocationID.x;

    // n is a signed uniform. Reject non-positive values before converting:
    // comparing a uint against a negative int implicitly converts it to a huge
    // uint, which would disable the early-out and let every invocation run.
    if (n <= 0) return;

    // Never touch elements past the end of any bound buffer; out-of-bounds
    // SSBO access is undefined unless robust buffer access is enabled.
    // Dividing the length (instead of computing index*4 + 3) avoids uint
    // overflow in the guard itself.
    int shortest = min(min(min(y.length(), result.length()),
                           min(k1.length(), k2.length())),
                       min(k3.length(), k4.length()));
    uint limit = min(uint(n), uint(shortest) / 4u);
    if (index >= limit) return;

    // First float of this invocation's 4-float group.
    uint base = index * 4u;

    // Load current state
    vec4 current_state = vec4(y[base], y[base + 1u], y[base + 2u], y[base + 3u]);

    // Load all k values
    vec4 k1_val = vec4(k1[base], k1[base + 1u], k1[base + 2u], k1[base + 3u]);
    vec4 k2_val = vec4(k2[base], k2[base + 1u], k2[base + 2u], k2[base + 3u]);
    vec4 k3_val = vec4(k3[base], k3[base + 1u], k3[base + 2u], k3[base + 3u]);
    vec4 k4_val = vec4(k4[base], k4[base + 1u], k4[base + 2u], k4[base + 3u]);

    // RK4 combination: y_new = y + (k1 + 2*k2 + 2*k3 + k4) / 6
    // No step size appears here, so the k values must already be scaled by h.
    // The expression order is kept as-is so floating-point rounding is unchanged.
    vec4 weighted_sum = k1_val + 2.0 * k2_val + 2.0 * k3_val + k4_val;
    vec4 new_state = current_state + weighted_sum / 6.0;

    // Store result
    result[base]      = new_state.x;
    result[base + 1u] = new_state.y;
    result[base + 2u] = new_state.z;
    result[base + 3u] = new_state.w;
}