use std::cell::OnceCell;
use smol_str::format_smolstr;
use crate::{
array::Array,
dtype::Dtype,
error::{
DtypeMismatchPayload, EmptyInputPayload, Error, LengthMismatchPayload, OutOfRangePayload,
RankMismatchPayload, Result, ShapePairMismatchPayload, UnsupportedDtypePayload,
},
ops::fast::metal_kernel::{KernelTemplateArg, MetalKernel, MetalKernelApplyConfig},
transforms::custom_vjp,
};
/// Metal Shading Language body of the forward KL-divergence kernel.
///
/// Each threadgroup handles one row, selected by `threadgroup_position_in_grid.y`,
/// of `V` contiguous logits in `logits_q` and `logits_p`. Let `p = softmax(logits_p)`
/// and `q = softmax(logits_q)` for that row. The kernel writes
/// `out[row] = sum_j p_j * (log p_j - log q_j)`, which is `KL(p || q)`, in two passes:
/// 1. A streaming max / sum-of-exp pass produces `lse_p` and `lse_q - lse_p`.
/// 2. A second pass accumulates the KL terms.
///
/// Assumptions baked into the source:
/// * 1024 threads per threadgroup (`block = 1024 * M`).
/// * At most 32 simdgroups of 32 lanes, since `shared` holds 32 float pairs.
/// * `T` and `V` are template arguments (see `template_for`).
///
/// Every write to `shared` that is read back by another simdgroup is followed by
/// a `mem_threadgroup` barrier. `shared` is also never overwritten while another
/// simdgroup may still be reading the previous contents.
const fn kl_forward_msl_source() -> &'static str {
    r#"
constexpr int M = 4;
constexpr int block = 1024 * M;
constexpr int full_blocks = V / block;
constexpr int extra = V - full_blocks * block;
threadgroup float shared[32 * 2];
uint out_idx = threadgroup_position_in_grid.y;
uint simd_lane_id = thread_index_in_simdgroup;
uint simd_group_id = simdgroup_index_in_threadgroup;
logits_q += out_idx * V;
logits_p += out_idx * V;
out += out_idx;
float lse_q_minus_p;
float lse_p;
{
float max_q = -1e30;
float max_p = -1e30;
float sum_exp_q = 0;
float sum_exp_p = 0;
int offset = thread_index_in_threadgroup * M;
for (int i = 0; i < full_blocks; i++) {
// Read and update q and p
float vals_q[M];
float vals_p[M];
for (int j=0; j<M; j++) {
vals_q[j] = logits_q[offset + j];
vals_p[j] = logits_p[offset + j];
}
float prev_max_q = max_q;
float prev_max_p = max_p;
for (int j=0; j<M; j++) {
max_q = max(max_q, vals_q[j]);
max_p = max(max_p, vals_p[j]);
}
sum_exp_q *= metal::fast::exp(prev_max_q - max_q);
sum_exp_p *= metal::fast::exp(prev_max_p - max_p);
for (int j=0; j<M; j++) {
sum_exp_q += metal::fast::exp(vals_q[j] - max_q);
sum_exp_p += metal::fast::exp(vals_p[j] - max_p);
}
// Move to the next block
offset += block;
}
if (extra > 0) {
// Read and update q and p
float vals_q[M];
float vals_p[M];
for (int j=0; j < M; j++) {
vals_q[j] = (offset + j < V) ? logits_q[offset + j] : -1e30;
vals_p[j] = (offset + j < V) ? logits_p[offset + j] : -1e30;
}
float prev_max_q = max_q;
float prev_max_p = max_p;
for (int j=0; j<M; j++) {
max_q = max(max_q, vals_q[j]);
max_p = max(max_p, vals_p[j]);
}
sum_exp_q *= metal::fast::exp(prev_max_q - max_q);
sum_exp_p *= metal::fast::exp(prev_max_p - max_p);
for (int j=0; j<M; j++) {
sum_exp_q += metal::fast::exp(vals_q[j] - max_q);
sum_exp_p += metal::fast::exp(vals_p[j] - max_p);
}
}
// Share the maxs across the threadgroup
float prev_max_q = max_q;
float prev_max_p = max_p;
max_q = simd_max(max_q);
max_p = simd_max(max_p);
if (simd_lane_id == 0) {
shared[simd_group_id * 2 + 0] = max_q;
shared[simd_group_id * 2 + 1] = max_p;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
max_q = shared[simd_lane_id * 2 + 0];
max_p = shared[simd_lane_id * 2 + 1];
max_q = simd_max(max_q);
max_p = simd_max(max_p);
// Share the sum_exp across the threadgroup
sum_exp_q *= metal::fast::exp(prev_max_q - max_q);
sum_exp_p *= metal::fast::exp(prev_max_p - max_p);
sum_exp_q = simd_sum(sum_exp_q);
sum_exp_p = simd_sum(sum_exp_p);
// Every simdgroup must finish reading the maxima before shared[] is reused for the sums.
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_lane_id == 0) {
shared[simd_group_id * 2 + 0] = sum_exp_q;
shared[simd_group_id * 2 + 1] = sum_exp_p;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
sum_exp_q = shared[simd_lane_id * 2 + 0];
sum_exp_p = shared[simd_lane_id * 2 + 1];
sum_exp_q = simd_sum(sum_exp_q);
sum_exp_p = simd_sum(sum_exp_p);
lse_p = max_p + metal::fast::log(sum_exp_p);
lse_q_minus_p = max_q + metal::fast::log(sum_exp_q) - lse_p;
}
threadgroup_barrier(mem_flags::mem_none);
{
float kl = 0;
int offset = thread_index_in_threadgroup * M;
for (int i = 0; i < full_blocks; i++) {
// Read and add to the kl
float vals_q[M];
float vals_p[M];
for (int j=0; j<M; j++) {
vals_q[j] = logits_q[offset + j];
vals_p[j] = logits_p[offset + j];
}
for (int j=0; j<M; j++) {
kl += metal::fast::exp(vals_p[j] - lse_p) * (vals_p[j] - vals_q[j] + lse_q_minus_p);
}
// Move to the next block
offset += block;
}
if (extra > 0) {
float vals_q[M];
float vals_p[M];
for (int j=0; j<M; j++) {
vals_q[j] = (offset + j < V) ? logits_q[offset + j] : -1e30;
vals_p[j] = (offset + j < V) ? logits_p[offset + j] : -1e30;
}
for (int j=0; j<M; j++) {
kl += metal::fast::exp(vals_p[j] - lse_p) * (vals_p[j] - vals_q[j] + lse_q_minus_p);
}
}
// Add the kl across the threadgroup
kl = simd_sum(kl);
if (simd_lane_id == 0) {
shared[simd_group_id] = kl;
}
// The partial sums were written to threadgroup memory; fence it so every lane sees them.
threadgroup_barrier(mem_flags::mem_threadgroup);
kl = shared[simd_lane_id];
kl = simd_sum(kl);
if (thread_index_in_threadgroup == 0) {
out[0] = static_cast<T>(kl);
}
}
"#
}
/// Metal Shading Language body of the backward KL-divergence kernel.
///
/// Each threadgroup handles one row, selected by `threadgroup_position_in_grid.y`.
/// It first recomputes `lse_q` and `lse_p` with a streaming max / sum-of-exp pass.
/// It then writes, for every `j < V`,
/// `out[row * V + j] = cotan[row] * (exp(logits_q[j] - lse_q) - exp(logits_p[j] - lse_p))`,
/// i.e. `cotan * (q_j - p_j)`, the gradient of `KL(p || q)` with respect to `logits_q`.
///
/// Assumptions baked into the source:
/// * 1024 threads per threadgroup (`block = 1024 * M`).
/// * At most 32 simdgroups of 32 lanes, since `shared` holds 32 float pairs.
/// * `T` and `V` are template arguments (see `template_for`).
///
/// `shared` is reused for the sum-of-exp partials only after a barrier that
/// guarantees every simdgroup has finished reading the max partials.
const fn kl_backward_msl_source() -> &'static str {
    r#"
constexpr int M = 4;
constexpr int block = 1024 * M;
constexpr int full_blocks = V / block;
constexpr int extra = V - full_blocks * block;
threadgroup float shared[32 * 2];
uint out_idx = threadgroup_position_in_grid.y;
uint simd_lane_id = thread_index_in_simdgroup;
uint simd_group_id = simdgroup_index_in_threadgroup;
logits_q += out_idx * V;
logits_p += out_idx * V;
out += out_idx * V;
cotan += out_idx;
float lse_q;
float lse_p;
{
float max_q = -1e30;
float max_p = -1e30;
float sum_exp_q = 0;
float sum_exp_p = 0;
int offset = thread_index_in_threadgroup * M;
for (int i = 0; i < full_blocks; i++) {
// Read and update q and p
float vals_q[M];
float vals_p[M];
for (int j=0; j<M; j++) {
vals_q[j] = logits_q[offset + j];
vals_p[j] = logits_p[offset + j];
}
float prev_max_q = max_q;
float prev_max_p = max_p;
for (int j=0; j<M; j++) {
max_q = max(max_q, vals_q[j]);
max_p = max(max_p, vals_p[j]);
}
sum_exp_q *= metal::fast::exp(prev_max_q - max_q);
sum_exp_p *= metal::fast::exp(prev_max_p - max_p);
for (int j=0; j<M; j++) {
sum_exp_q += metal::fast::exp(vals_q[j] - max_q);
sum_exp_p += metal::fast::exp(vals_p[j] - max_p);
}
// Move to the next block
offset += block;
}
if (extra > 0) {
// Read and update q and p
float vals_q[M];
float vals_p[M];
for (int j=0; j < M; j++) {
vals_q[j] = (offset + j < V) ? logits_q[offset + j] : -1e30;
vals_p[j] = (offset + j < V) ? logits_p[offset + j] : -1e30;
}
float prev_max_q = max_q;
float prev_max_p = max_p;
for (int j=0; j<M; j++) {
max_q = max(max_q, vals_q[j]);
max_p = max(max_p, vals_p[j]);
}
sum_exp_q *= metal::fast::exp(prev_max_q - max_q);
sum_exp_p *= metal::fast::exp(prev_max_p - max_p);
for (int j=0; j<M; j++) {
sum_exp_q += metal::fast::exp(vals_q[j] - max_q);
sum_exp_p += metal::fast::exp(vals_p[j] - max_p);
}
}
// Share the maxs across the threadgroup
float prev_max_q = max_q;
float prev_max_p = max_p;
max_q = simd_max(max_q);
max_p = simd_max(max_p);
if (simd_lane_id == 0) {
shared[simd_group_id * 2 + 0] = max_q;
shared[simd_group_id * 2 + 1] = max_p;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
max_q = shared[simd_lane_id * 2 + 0];
max_p = shared[simd_lane_id * 2 + 1];
max_q = simd_max(max_q);
max_p = simd_max(max_p);
// Share the sum_exp across the threadgroup
sum_exp_q *= metal::fast::exp(prev_max_q - max_q);
sum_exp_p *= metal::fast::exp(prev_max_p - max_p);
sum_exp_q = simd_sum(sum_exp_q);
sum_exp_p = simd_sum(sum_exp_p);
// Every simdgroup must finish reading the maxima before shared[] is reused for the sums.
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_lane_id == 0) {
shared[simd_group_id * 2 + 0] = sum_exp_q;
shared[simd_group_id * 2 + 1] = sum_exp_p;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
sum_exp_q = shared[simd_lane_id * 2 + 0];
sum_exp_p = shared[simd_lane_id * 2 + 1];
sum_exp_q = simd_sum(sum_exp_q);
sum_exp_p = simd_sum(sum_exp_p);
lse_p = max_p + metal::fast::log(sum_exp_p);
lse_q = max_q + metal::fast::log(sum_exp_q);
}
threadgroup_barrier(mem_flags::mem_none);
{
float kl = 0;
float c = cotan[0];
int offset = thread_index_in_threadgroup * M;
for (int i = 0; i < full_blocks; i++) {
// Read and add to the kl
float vals_q[M];
float vals_p[M];
for (int j=0; j<M; j++) {
vals_q[j] = logits_q[offset + j];
vals_p[j] = logits_p[offset + j];
}
for (int j=0; j<M; j++) {
out[offset + j] = static_cast<T>(
c * (metal::fast::exp(vals_q[j] - lse_q) - metal::fast::exp(vals_p[j] - lse_p)));
}
// Move to the next block
offset += block;
}
if (extra > 0) {
float vals_q[M];
float vals_p[M];
for (int j=0; j<M; j++) {
vals_q[j] = (offset + j < V) ? logits_q[offset + j] : -1e30;
vals_p[j] = (offset + j < V) ? logits_p[offset + j] : -1e30;
}
for (int j=0; j<M; j++) {
if (offset + j < V) {
out[offset + j] = static_cast<T>(
c * (metal::fast::exp(vals_q[j] - lse_q) - metal::fast::exp(vals_p[j] - lse_p)));
}
}
}
}
"#
}
/// Metal Shading Language body of the forward Jensen-Shannon-divergence kernel.
///
/// Each threadgroup handles one row, selected by `threadgroup_position_in_grid.y`.
/// It recomputes `lse_p` and `lse_q`, then accumulates two sums:
/// * `kl_p = sum_j p_j * (log 2 - log(1 + q_j / p_j))`, which is `KL(p || m)`;
/// * `kl_q = sum_j q_j * (log 2 - log(1 + p_j / q_j))`, which is `KL(q || m)`.
///
/// Here `m = (p + q) / 2`. The kernel writes `out[row] = 0.5 * kl_p + 0.5 * kl_q`.
/// It also writes `out_kl_q[row] = kl_q`, which the backward kernel consumes.
///
/// Assumptions baked into the source:
/// * 1024 threads per threadgroup (`block = 1024 * M`).
/// * At most 32 simdgroups of 32 lanes, since `shared` holds 32 float pairs.
/// * `T` and `V` are template arguments (see `template_for`).
///
/// `shared` is reused for the sum-of-exp partials only after a barrier that
/// guarantees every simdgroup has finished reading the max partials.
const fn js_forward_msl_source() -> &'static str {
    r#"
constexpr int M = 4;
constexpr int block = 1024 * M;
constexpr int full_blocks = V / block;
constexpr int extra = V - full_blocks * block;
threadgroup float shared[32 * 2];
uint out_idx = threadgroup_position_in_grid.y;
uint simd_lane_id = thread_index_in_simdgroup;
uint simd_group_id = simdgroup_index_in_threadgroup;
logits_q += out_idx * V;
logits_p += out_idx * V;
out += out_idx;
out_kl_q += out_idx;
float lse_p;
float lse_q;
{
float max_q = -1e30;
float max_p = -1e30;
float sum_exp_q = 0;
float sum_exp_p = 0;
int offset = thread_index_in_threadgroup * M;
for (int i = 0; i < full_blocks; i++) {
// Read and update q and p
float vals_q[M];
float vals_p[M];
for (int j=0; j<M; j++) {
vals_q[j] = logits_q[offset + j];
vals_p[j] = logits_p[offset + j];
}
float prev_max_q = max_q;
float prev_max_p = max_p;
for (int j=0; j<M; j++) {
max_q = max(max_q, vals_q[j]);
max_p = max(max_p, vals_p[j]);
}
sum_exp_q *= metal::fast::exp(prev_max_q - max_q);
sum_exp_p *= metal::fast::exp(prev_max_p - max_p);
for (int j=0; j<M; j++) {
sum_exp_q += metal::fast::exp(vals_q[j] - max_q);
sum_exp_p += metal::fast::exp(vals_p[j] - max_p);
}
// Move to the next block
offset += block;
}
if (extra > 0) {
// Read and update q and p
float vals_q[M];
float vals_p[M];
for (int j=0; j < M; j++) {
vals_q[j] = (offset + j < V) ? logits_q[offset + j] : -1e30;
vals_p[j] = (offset + j < V) ? logits_p[offset + j] : -1e30;
}
float prev_max_q = max_q;
float prev_max_p = max_p;
for (int j=0; j<M; j++) {
max_q = max(max_q, vals_q[j]);
max_p = max(max_p, vals_p[j]);
}
sum_exp_q *= metal::fast::exp(prev_max_q - max_q);
sum_exp_p *= metal::fast::exp(prev_max_p - max_p);
for (int j=0; j<M; j++) {
sum_exp_q += metal::fast::exp(vals_q[j] - max_q);
sum_exp_p += metal::fast::exp(vals_p[j] - max_p);
}
}
// Share the maxs across the threadgroup
float prev_max_q = max_q;
float prev_max_p = max_p;
max_q = simd_max(max_q);
max_p = simd_max(max_p);
if (simd_lane_id == 0) {
shared[simd_group_id * 2 + 0] = max_q;
shared[simd_group_id * 2 + 1] = max_p;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
max_q = shared[simd_lane_id * 2 + 0];
max_p = shared[simd_lane_id * 2 + 1];
max_q = simd_max(max_q);
max_p = simd_max(max_p);
// Share the sum_exp across the threadgroup
sum_exp_q *= metal::fast::exp(prev_max_q - max_q);
sum_exp_p *= metal::fast::exp(prev_max_p - max_p);
sum_exp_q = simd_sum(sum_exp_q);
sum_exp_p = simd_sum(sum_exp_p);
// Every simdgroup must finish reading the maxima before shared[] is reused for the sums.
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_lane_id == 0) {
shared[simd_group_id * 2 + 0] = sum_exp_q;
shared[simd_group_id * 2 + 1] = sum_exp_p;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
sum_exp_q = shared[simd_lane_id * 2 + 0];
sum_exp_p = shared[simd_lane_id * 2 + 1];
sum_exp_q = simd_sum(sum_exp_q);
sum_exp_p = simd_sum(sum_exp_p);
lse_p = max_p + metal::fast::log(sum_exp_p);
lse_q = max_q + metal::fast::log(sum_exp_q);
}
threadgroup_barrier(mem_flags::mem_none);
{
float kl_p = 0;
float kl_q = 0;
const float logtwo = metal::fast::log(static_cast<float>(2));
int offset = thread_index_in_threadgroup * M;
for (int i = 0; i < full_blocks; i++) {
// Read and add to the kl_p and kl_q
float vals_q[M];
float vals_p[M];
for (int j=0; j<M; j++) {
vals_q[j] = logits_q[offset + j];
vals_p[j] = logits_p[offset + j];
}
for (int j=0; j<M; j++) {
float logp_j = vals_p[j] - lse_p;
float logq_j = vals_q[j] - lse_q;
float p_j = metal::fast::exp(logp_j);
float q_j = metal::fast::exp(logq_j);
kl_p += p_j * (logtwo - metal::fast::log(1 + metal::fast::exp(logq_j - logp_j)));
kl_q += q_j * (logtwo - metal::fast::log(1 + metal::fast::exp(logp_j - logq_j)));
}
// Move to the next block
offset += block;
}
if (extra > 0) {
float vals_q[M];
float vals_p[M];
for (int j=0; j<M; j++) {
vals_q[j] = (offset + j < V) ? logits_q[offset + j] : -1e30;
vals_p[j] = (offset + j < V) ? logits_p[offset + j] : -1e30;
}
for (int j=0; j<M; j++) {
float logp_j = vals_p[j] - lse_p;
float logq_j = vals_q[j] - lse_q;
float p_j = metal::fast::exp(logp_j);
float q_j = metal::fast::exp(logq_j);
kl_p += p_j * (logtwo - metal::fast::log(1 + metal::fast::exp(logq_j - logp_j)));
kl_q += q_j * (logtwo - metal::fast::log(1 + metal::fast::exp(logp_j - logq_j)));
}
}
// Add the kl_p and kl_q across the threadgroup
kl_p = simd_sum(kl_p);
kl_q = simd_sum(kl_q);
if (simd_lane_id == 0) {
shared[simd_group_id * 2 + 0] = kl_p;
shared[simd_group_id * 2 + 1] = kl_q;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
kl_p = shared[simd_lane_id * 2 + 0];
kl_q = shared[simd_lane_id * 2 + 1];
kl_p = simd_sum(kl_p);
kl_q = simd_sum(kl_q);
if (thread_index_in_threadgroup == 0) {
out[0] = static_cast<T>(0.5 * kl_p + 0.5 * kl_q);
out_kl_q[0] = static_cast<T>(kl_q);
}
}
"#
}
/// Metal Shading Language body of the backward Jensen-Shannon-divergence kernel.
///
/// Each threadgroup handles one row, selected by `threadgroup_position_in_grid.y`.
/// It recomputes `lse_p` and `lse_q`, then writes, for every `j < V`,
/// `out_q[row * V + j] = cotan[row] * 0.5 * q_j * (log 2 - log(1 + p_j / q_j) - kl_q)`.
/// Here `kl_q = output_kl_q[row]` is the per-row term produced by the forward
/// kernel. Only the gradient with respect to `logits_q` is computed.
///
/// Assumptions baked into the source:
/// * 1024 threads per threadgroup (`block = 1024 * M`).
/// * At most 32 simdgroups of 32 lanes, since `shared` holds 32 float pairs.
/// * `T` and `V` are template arguments (see `template_for`).
///
/// `shared` is reused for the sum-of-exp partials only after a barrier that
/// guarantees every simdgroup has finished reading the max partials.
const fn js_backward_msl_source() -> &'static str {
    r#"
constexpr int M = 4;
constexpr int block = 1024 * M;
constexpr int full_blocks = V / block;
constexpr int extra = V - full_blocks * block;
threadgroup float shared[32 * 2];
uint out_idx = threadgroup_position_in_grid.y;
uint simd_lane_id = thread_index_in_simdgroup;
uint simd_group_id = simdgroup_index_in_threadgroup;
logits_q += out_idx * V;
logits_p += out_idx * V;
out_q += out_idx * V;
cotan += out_idx;
output_kl_q += out_idx;
float lse_q;
float lse_p;
{
float max_q = -1e30;
float max_p = -1e30;
float sum_exp_q = 0;
float sum_exp_p = 0;
int offset = thread_index_in_threadgroup * M;
for (int i = 0; i < full_blocks; i++) {
// Read and update q and p
float vals_q[M];
float vals_p[M];
for (int j=0; j<M; j++) {
vals_q[j] = logits_q[offset + j];
vals_p[j] = logits_p[offset + j];
}
float prev_max_q = max_q;
float prev_max_p = max_p;
for (int j=0; j<M; j++) {
max_q = max(max_q, vals_q[j]);
max_p = max(max_p, vals_p[j]);
}
sum_exp_q *= metal::fast::exp(prev_max_q - max_q);
sum_exp_p *= metal::fast::exp(prev_max_p - max_p);
for (int j=0; j<M; j++) {
sum_exp_q += metal::fast::exp(vals_q[j] - max_q);
sum_exp_p += metal::fast::exp(vals_p[j] - max_p);
}
// Move to the next block
offset += block;
}
if (extra > 0) {
// Read and update q and p
float vals_q[M];
float vals_p[M];
for (int j=0; j < M; j++) {
vals_q[j] = (offset + j < V) ? logits_q[offset + j] : -1e30;
vals_p[j] = (offset + j < V) ? logits_p[offset + j] : -1e30;
}
float prev_max_q = max_q;
float prev_max_p = max_p;
for (int j=0; j<M; j++) {
max_q = max(max_q, vals_q[j]);
max_p = max(max_p, vals_p[j]);
}
sum_exp_q *= metal::fast::exp(prev_max_q - max_q);
sum_exp_p *= metal::fast::exp(prev_max_p - max_p);
for (int j=0; j<M; j++) {
sum_exp_q += metal::fast::exp(vals_q[j] - max_q);
sum_exp_p += metal::fast::exp(vals_p[j] - max_p);
}
}
// Share the maxs across the threadgroup
float prev_max_q = max_q;
float prev_max_p = max_p;
max_q = simd_max(max_q);
max_p = simd_max(max_p);
if (simd_lane_id == 0) {
shared[simd_group_id * 2 + 0] = max_q;
shared[simd_group_id * 2 + 1] = max_p;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
max_q = shared[simd_lane_id * 2 + 0];
max_p = shared[simd_lane_id * 2 + 1];
max_q = simd_max(max_q);
max_p = simd_max(max_p);
// Share the sum_exp across the threadgroup
sum_exp_q *= metal::fast::exp(prev_max_q - max_q);
sum_exp_p *= metal::fast::exp(prev_max_p - max_p);
sum_exp_q = simd_sum(sum_exp_q);
sum_exp_p = simd_sum(sum_exp_p);
// Every simdgroup must finish reading the maxima before shared[] is reused for the sums.
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_lane_id == 0) {
shared[simd_group_id * 2 + 0] = sum_exp_q;
shared[simd_group_id * 2 + 1] = sum_exp_p;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
sum_exp_q = shared[simd_lane_id * 2 + 0];
sum_exp_p = shared[simd_lane_id * 2 + 1];
sum_exp_q = simd_sum(sum_exp_q);
sum_exp_p = simd_sum(sum_exp_p);
lse_p = max_p + metal::fast::log(sum_exp_p);
lse_q = max_q + metal::fast::log(sum_exp_q);
}
threadgroup_barrier(mem_flags::mem_none);
{
float c = cotan[0];
const float logtwo = metal::fast::log(static_cast<float>(2));
float kl_q = output_kl_q[0];
int offset = thread_index_in_threadgroup * M;
for (int i = 0; i < full_blocks; i++) {
// Read and compute vjp for logits_q
float vals_q[M];
float vals_p[M];
for (int j=0; j<M; j++) {
vals_q[j] = logits_q[offset + j];
vals_p[j] = logits_p[offset + j];
}
for (int j=0; j<M; j++) {
float logp_j = vals_p[j] - lse_p;
float logq_j = vals_q[j] - lse_q;
float q_j = metal::fast::exp(logq_j);
out_q[offset + j] = static_cast<T>(
c * 0.5 * q_j * (logtwo - metal::fast::log(1 + metal::fast::exp(logp_j - logq_j)) - kl_q)
);
}
// Move to the next block
offset += block;
}
if (extra > 0) {
float vals_q[M];
float vals_p[M];
for (int j=0; j<M; j++) {
vals_q[j] = (offset + j < V) ? logits_q[offset + j] : -1e30;
vals_p[j] = (offset + j < V) ? logits_p[offset + j] : -1e30;
}
for (int j=0; j<M; j++) {
if (offset + j < V) {
float logp_j = vals_p[j] - lse_p;
float logq_j = vals_q[j] - lse_q;
float q_j = metal::fast::exp(logq_j);
out_q[offset + j] = static_cast<T>(
c * 0.5 * q_j * (logtwo - metal::fast::log(1 + metal::fast::exp(logp_j - logq_j)) - kl_q)
);
}
}
}
}
"#
}
// Per-thread caches for the four loss kernels. Each `OnceCell` starts empty and
// is filled by `with_kernel` the first time that kernel is requested on the
// current thread. After a successful build the cached kernel is reused; a failed
// build leaves the cell empty.
thread_local! {
// `kl_forward`, built by `build_kl_forward_kernel`.
static KL_FORWARD: OnceCell<MetalKernel> = const { OnceCell::new() };
// `kl_backward`, built by `build_kl_backward_kernel`.
static KL_BACKWARD: OnceCell<MetalKernel> = const { OnceCell::new() };
// `js_forward`, built by `build_js_forward_kernel`.
static JS_FORWARD: OnceCell<MetalKernel> = const { OnceCell::new() };
// `js_backward`, built by `build_js_backward_kernel`.
static JS_BACKWARD: OnceCell<MetalKernel> = const { OnceCell::new() };
}
/// Runs `f` with the kernel cached in the thread-local `cell`.
///
/// If the cell is still empty on this thread, the kernel is first built with
/// `build` and stored.
///
/// # Errors
/// Returns the error from `build` (the cell then stays empty, so a later call
/// retries the build) or the error returned by `f`.
fn with_kernel<F, R>(
    cell: &'static std::thread::LocalKey<OnceCell<MetalKernel>>,
    build: impl FnOnce() -> Result<MetalKernel>,
    f: F,
) -> Result<R>
where
    F: FnOnce(&MetalKernel) -> Result<R>,
{
    cell.with(|slot| {
        if slot.get().is_none() {
            // A rejected `set` is only possible if `build` re-entrantly filled the
            // cell; the already-cached kernel is used in that case.
            let _ = slot.set(build()?);
        }
        let kernel = slot
            .get()
            .expect("kernel must be initialized by the preceding set");
        f(kernel)
    })
}
/// Builds the `kl_forward` kernel from `kl_forward_msl_source`.
///
/// Inputs are `logits_q` and `logits_p`; there is a single output, `out`.
fn build_kl_forward_kernel() -> Result<MetalKernel> {
    const INPUT_NAMES: &[&str] = &["logits_q", "logits_p"];
    const OUTPUT_NAMES: &[&str] = &["out"];
    let source = kl_forward_msl_source();
    // Trailing args are presumably header / ensure_row_contiguous / atomic_outputs,
    // mirroring MLX `metal_kernel` — confirm against `MetalKernel::new`.
    MetalKernel::new("kl_forward", INPUT_NAMES, OUTPUT_NAMES, source, "", true, false)
}
/// Builds the `kl_backward` kernel from `kl_backward_msl_source`.
///
/// Inputs are `logits_q`, `logits_p` and the per-row `cotan`; the single output
/// `out` holds the gradient with respect to `logits_q`.
fn build_kl_backward_kernel() -> Result<MetalKernel> {
    const INPUT_NAMES: &[&str] = &["logits_q", "logits_p", "cotan"];
    const OUTPUT_NAMES: &[&str] = &["out"];
    let source = kl_backward_msl_source();
    // Trailing args are presumably header / ensure_row_contiguous / atomic_outputs,
    // mirroring MLX `metal_kernel` — confirm against `MetalKernel::new`.
    MetalKernel::new("kl_backward", INPUT_NAMES, OUTPUT_NAMES, source, "", true, false)
}
/// Builds the `js_forward` kernel from `js_forward_msl_source`.
///
/// Inputs are `logits_q` and `logits_p`. There are two outputs: `out` (the loss)
/// and `out_kl_q` (the per-row term reused by the backward pass).
fn build_js_forward_kernel() -> Result<MetalKernel> {
    const INPUT_NAMES: &[&str] = &["logits_q", "logits_p"];
    const OUTPUT_NAMES: &[&str] = &["out", "out_kl_q"];
    let source = js_forward_msl_source();
    // Trailing args are presumably header / ensure_row_contiguous / atomic_outputs,
    // mirroring MLX `metal_kernel` — confirm against `MetalKernel::new`.
    MetalKernel::new("js_forward", INPUT_NAMES, OUTPUT_NAMES, source, "", true, false)
}
/// Builds the `js_backward` kernel from `js_backward_msl_source`.
///
/// Inputs are `logits_q`, `logits_p`, the per-row `cotan` and the forward pass's
/// `output_kl_q`. The single output `out_q` holds the gradient with respect to
/// `logits_q`.
fn build_js_backward_kernel() -> Result<MetalKernel> {
    const INPUT_NAMES: &[&str] = &["logits_q", "logits_p", "cotan", "output_kl_q"];
    const OUTPUT_NAMES: &[&str] = &["out_q"];
    let source = js_backward_msl_source();
    // Trailing args are presumably header / ensure_row_contiguous / atomic_outputs,
    // mirroring MLX `metal_kernel` — confirm against `MetalKernel::new`.
    MetalKernel::new("js_backward", INPUT_NAMES, OUTPUT_NAMES, source, "", true, false)
}
fn n_outs_of(logits: &Array) -> Result<i32> {
let shape = logits.shape();
let v = shape.last().copied().ok_or_else(|| {
Error::RankMismatch(RankMismatchPayload::new(
"mlxrs::lm::tuner::losses: logits must have rank >= 1",
0,
Vec::new(),
))
})?;
if v == 0 {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"mlxrs::lm::tuner::losses: logits last dimension",
"must be > 0",
"0",
)));
}
let total: usize = shape.iter().product();
let n_outs = total / v;
i32::try_from(n_outs).map_err(|_| {
Error::OutOfRange(OutOfRangePayload::new(
"mlxrs::lm::tuner::losses: n_outs",
"must fit in i32",
format_smolstr!("{n_outs}"),
))
})
}
fn vocab_of(logits: &Array) -> Result<i32> {
let shape = logits.shape();
let v = shape.last().copied().ok_or_else(|| {
Error::RankMismatch(RankMismatchPayload::new(
"mlxrs::lm::tuner::losses: logits must have rank >= 1",
0,
Vec::new(),
))
})?;
i32::try_from(v).map_err(|_| {
Error::OutOfRange(OutOfRangePayload::new(
"mlxrs::lm::tuner::losses: vocab size",
"must fit in i32",
format_smolstr!("{v}"),
))
})
}
/// All dimensions of `logits` except the last, converted to `i32`.
///
/// This is the shape of the per-row loss outputs.
///
/// # Errors
/// * `RankMismatch` if `logits` has rank 0.
/// * `OutOfRange` for the first leading dimension that does not fit in `i32`.
fn leading_shape_i32(logits: &Array) -> Result<Vec<i32>> {
    let shape = logits.shape();
    let Some((_, leading)) = shape.split_last() else {
        return Err(Error::RankMismatch(RankMismatchPayload::new(
            "mlxrs::lm::tuner::losses: logits must have rank >= 1",
            0,
            Vec::new(),
        )));
    };
    let mut dims = Vec::with_capacity(leading.len());
    for &d in leading {
        let dim = i32::try_from(d).map_err(|_| {
            Error::OutOfRange(OutOfRangePayload::new(
                "mlxrs::lm::tuner::losses: shape dim",
                "must fit in i32",
                format_smolstr!("{d}"),
            ))
        })?;
        dims.push(dim);
    }
    Ok(dims)
}
fn full_shape_i32(logits: &Array) -> Result<Vec<i32>> {
logits
.shape()
.iter()
.map(|&d| {
i32::try_from(d).map_err(|_| {
Error::OutOfRange(OutOfRangePayload::new(
"mlxrs::lm::tuner::losses: shape dim",
"must fit in i32",
format_smolstr!("{d}"),
))
})
})
.collect()
}
#[allow(clippy::too_many_arguments)]
fn validate_inputs(
logits_q: &Array,
logits_p: &Array,
ctx_q: &'static str,
ctx_p: &'static str,
ctx_pair: &'static str,
ctx_last: &'static str,
ctx_dim: &'static str,
ctx_dtype: &'static str,
) -> Result<()> {
let sq = logits_q.shape();
let sp = logits_p.shape();
if logits_q.ndim() < 2 {
return Err(Error::RankMismatch(RankMismatchPayload::new(
ctx_q,
logits_q.ndim() as u32,
sq.to_vec(),
)));
}
if logits_p.ndim() < 2 {
return Err(Error::RankMismatch(RankMismatchPayload::new(
ctx_p,
logits_p.ndim() as u32,
sp.to_vec(),
)));
}
if sq != sp {
return Err(Error::ShapePairMismatch(ShapePairMismatchPayload::new(
ctx_pair,
sq.to_vec(),
sp.to_vec(),
)));
}
let last = *sq.last().expect("rank>=2 guaranteed by checks above");
if last == 0 {
return Err(Error::OutOfRange(OutOfRangePayload::new(
ctx_last,
"must be > 0",
"0",
)));
}
for &d in sq.iter() {
if i32::try_from(d).is_err() {
return Err(Error::OutOfRange(OutOfRangePayload::new(
ctx_dim,
"must fit in i32",
format_smolstr!("{d}"),
)));
}
}
let dq = logits_q.dtype()?;
let dp = logits_p.dtype()?;
if dq != dp {
return Err(Error::DtypeMismatch(DtypeMismatchPayload::new(dq, dp)));
}
match dq {
Dtype::F32 | Dtype::F16 | Dtype::BF16 => {}
_ => {
return Err(Error::UnsupportedDtype(UnsupportedDtypePayload::new(
ctx_dtype,
dq,
&[Dtype::F32, Dtype::F16, Dtype::BF16],
)));
}
}
Ok(())
}
/// Template arguments shared by every loss kernel, in this order:
/// * `T`: the element dtype.
/// * `V`: the row length (vocabulary size).
fn template_for(dtype: Dtype, vocab: i32) -> Vec<(String, KernelTemplateArg)> {
    let mut args = Vec::with_capacity(2);
    args.push(("T".to_owned(), KernelTemplateArg::Dtype(dtype)));
    args.push(("V".to_owned(), KernelTemplateArg::Int(vocab)));
    args
}
/// Dispatches the forward KL kernel and returns the per-row loss.
///
/// The result has the shape of `logits_q` without its last dimension.
///
/// The grid is `[1024, rows, 1]` with `[1024, 1, 1]` threadgroups; the kernel
/// selects its row via `threadgroup_position_in_grid.y`.
fn kl_forward_apply(logits_q: &Array, logits_p: &Array) -> Result<Array> {
    let dtype = logits_q.dtype()?;
    let vocab = vocab_of(logits_q)?;
    let n_outs = n_outs_of(logits_q)?;
    let loss_shape = leading_shape_i32(logits_q)?;
    // `n_outs` comes from a usize via `i32::try_from`, so it is never negative.
    let grid = [1024, n_outs as u32, 1];
    let threadgroup = [1024, 1, 1];
    let config = MetalKernelApplyConfig::new(grid, threadgroup, vec![loss_shape], vec![dtype])?
        .with_template(template_for(dtype, vocab));
    with_kernel(&KL_FORWARD, build_kl_forward_kernel, |kernel| {
        let mut produced = kernel.apply(&[logits_q, logits_p], &config)?;
        Ok(produced.swap_remove(0))
    })
}
fn kl_backward_apply(logits_q: &Array, logits_p: &Array, cotangent: &Array) -> Result<Array> {
let dtype = logits_q.dtype()?;
let vocab = vocab_of(logits_q)?;
let cot_size = i32::try_from(cotangent.size()).map_err(|_| {
Error::OutOfRange(OutOfRangePayload::new(
"mlxrs::lm::tuner::losses::kl_backward: cotangent size",
"must fit in i32",
format_smolstr!("{}", cotangent.size()),
))
})?;
let out_shape = full_shape_i32(logits_q)?;
let cfg = MetalKernelApplyConfig::new(
[1024, cot_size as u32, 1],
[1024, 1, 1],
vec![out_shape],
vec![dtype],
)?
.with_template(template_for(dtype, vocab));
with_kernel(&KL_BACKWARD, build_kl_backward_kernel, |kernel| {
let mut outputs = kernel.apply(&[logits_q, logits_p, cotangent], &cfg)?;
Ok(outputs.swap_remove(0))
})
}
/// Dispatches the forward Jensen-Shannon kernel and returns `(loss, kl_q)`.
///
/// Both arrays have the shape of `logits_q` without its last dimension. `kl_q` is
/// the per-row `sum_j q_j * log(q_j / m_j)` term (with `m = (p + q) / 2`); the
/// backward kernel reuses it.
///
/// # Errors
/// `LengthMismatch` if the kernel does not return exactly two outputs, plus any
/// shape, dtype or kernel error from the helpers.
fn js_forward_apply(logits_q: &Array, logits_p: &Array) -> Result<(Array, Array)> {
    let dtype = logits_q.dtype()?;
    let vocab = vocab_of(logits_q)?;
    let n_outs = n_outs_of(logits_q)?;
    let leading = leading_shape_i32(logits_q)?;
    let config = MetalKernelApplyConfig::new(
        [1024, n_outs as u32, 1],
        [1024, 1, 1],
        vec![leading.clone(), leading],
        vec![dtype; 2],
    )?
    .with_template(template_for(dtype, vocab));
    with_kernel(&JS_FORWARD, build_js_forward_kernel, |kernel| {
        let mut produced = kernel.apply(&[logits_q, logits_p], &config)?;
        if produced.len() != 2 {
            return Err(Error::LengthMismatch(LengthMismatchPayload::new(
                "mlxrs::lm::tuner::losses::js_forward: kernel outputs",
                2,
                produced.len(),
            )));
        }
        // Outputs are ordered [out, out_kl_q]; pop from the back.
        let kl_q = produced.pop().expect("length checked above");
        let loss = produced.pop().expect("length checked above");
        Ok((loss, kl_q))
    })
}
fn js_backward_apply(
logits_q: &Array,
logits_p: &Array,
cotan: &Array,
kl_q: &Array,
) -> Result<Array> {
let dtype = logits_q.dtype()?;
let vocab = vocab_of(logits_q)?;
let cot_size = i32::try_from(cotan.size()).map_err(|_| {
Error::OutOfRange(OutOfRangePayload::new(
"mlxrs::lm::tuner::losses::js_backward: cotan size",
"must fit in i32",
format_smolstr!("{}", cotan.size()),
))
})?;
let out_shape = full_shape_i32(logits_q)?;
let cfg = MetalKernelApplyConfig::new(
[1024, cot_size as u32, 1],
[1024, 1, 1],
vec![out_shape],
vec![dtype],
)?
.with_template(template_for(dtype, vocab));
with_kernel(&JS_BACKWARD, build_js_backward_kernel, |kernel| {
let mut outputs = kernel.apply(&[logits_q, logits_p, cotan, kl_q], &cfg)?;
Ok(outputs.swap_remove(0))
})
}
pub fn kl_div_loss(logits_q: &Array, logits_p: &Array) -> Result<Array> {
validate_inputs(
logits_q,
logits_p,
"kl_div_loss: logits_q rank (must be >= 2; reshape rank-1 [V] to [1, V])",
"kl_div_loss: logits_p rank (must be >= 2; reshape rank-1 [V] to [1, V])",
"kl_div_loss: logits_q vs logits_p shape",
"kl_div_loss: logits last dimension",
"kl_div_loss: shape dim",
"kl_div_loss: logits dtype (cast with .astype(Dtype::F32) before calling)",
)?;
let wrapped = custom_vjp(
|inputs: &[Array]| -> Result<Vec<Array>> {
let out = kl_forward_apply(&inputs[0], &inputs[1])?;
Ok(vec![out])
},
|primals: &[Array], cotangents: &[Array], _outputs: &[Array]| -> Result<Vec<Array>> {
let logits_q = &primals[0];
let logits_p = &primals[1];
let cotangent = &cotangents[0];
let dq = kl_backward_apply(logits_q, logits_p, cotangent)?;
let dp = crate::ops::misc::zeros_like(logits_p)?;
Ok(vec![dq, dp])
},
)?;
let inputs = [logits_q.try_clone()?, logits_p.try_clone()?];
let mut outputs = wrapped(&inputs)?;
if outputs.is_empty() {
return Err(Error::EmptyInput(EmptyInputPayload::new(
"kl_div_loss: forward closure output",
)));
}
Ok(outputs.swap_remove(0))
}
pub fn js_div_loss(logits_q: &Array, logits_p: &Array) -> Result<Array> {
validate_inputs(
logits_q,
logits_p,
"js_div_loss: logits_q rank (must be >= 2; reshape rank-1 [V] to [1, V])",
"js_div_loss: logits_p rank (must be >= 2; reshape rank-1 [V] to [1, V])",
"js_div_loss: logits_q vs logits_p shape",
"js_div_loss: logits last dimension",
"js_div_loss: shape dim",
"js_div_loss: logits dtype (cast with .astype(Dtype::F32) before calling)",
)?;
let wrapped = custom_vjp(
|inputs: &[Array]| -> Result<Vec<Array>> {
let (loss, kl_q) = js_forward_apply(&inputs[0], &inputs[1])?;
Ok(vec![loss, kl_q])
},
|primals: &[Array], cotangents: &[Array], outputs: &[Array]| -> Result<Vec<Array>> {
let logits_q = &primals[0];
let logits_p = &primals[1];
let cotan = &cotangents[0];
let kl_q = &outputs[1];
let dq = js_backward_apply(logits_q, logits_p, cotan, kl_q)?;
let dp = crate::ops::misc::zeros_like(logits_p)?;
Ok(vec![dq, dp])
},
)?;
let inputs = [logits_q.try_clone()?, logits_p.try_clone()?];
let mut outputs = wrapped(&inputs)?;
if outputs.is_empty() {
return Err(Error::EmptyInput(EmptyInputPayload::new(
"js_div_loss: forward closure output",
)));
}
Ok(outputs.swap_remove(0))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn kl_forward_msl_source_contains_signature_landmarks() {
let s = kl_forward_msl_source();
assert!(s.contains("constexpr int M = 4;"));
assert!(s.contains("constexpr int block = 1024 * M;"));
assert!(s.contains("threadgroup float shared[32 * 2];"));
assert!(s.contains("logits_q += out_idx * V;"));
assert!(s.contains("logits_p += out_idx * V;"));
assert!(s.contains("out += out_idx;"));
assert!(s.contains("out[0] = static_cast<T>(kl);"));
assert!(s.contains("lse_q_minus_p"));
assert!(!s.contains("kl_p +="));
}
#[test]
fn kl_backward_msl_source_contains_signature_landmarks() {
let s = kl_backward_msl_source();
assert!(s.contains("constexpr int M = 4;"));
assert!(s.contains("out += out_idx * V;"));
assert!(s.contains("cotan += out_idx;"));
assert!(
s.contains("c * (metal::fast::exp(vals_q[j] - lse_q) - metal::fast::exp(vals_p[j] - lse_p))")
);
assert!(!s.contains("output_kl_q"));
}
#[test]
fn js_forward_msl_source_contains_signature_landmarks() {
let s = js_forward_msl_source();
assert!(s.contains("constexpr int M = 4;"));
assert!(s.contains("out += out_idx;"));
assert!(s.contains("out_kl_q += out_idx;"));
assert!(s.contains("logtwo - metal::fast::log(1 + metal::fast::exp(logq_j - logp_j))"));
assert!(s.contains("logtwo - metal::fast::log(1 + metal::fast::exp(logp_j - logq_j))"));
assert!(s.contains("out[0] = static_cast<T>(0.5 * kl_p + 0.5 * kl_q);"));
assert!(s.contains("out_kl_q[0] = static_cast<T>(kl_q);"));
}
#[test]
fn js_backward_msl_source_contains_signature_landmarks() {
let s = js_backward_msl_source();
assert!(s.contains("constexpr int M = 4;"));
assert!(s.contains("output_kl_q += out_idx;"));
assert!(s.contains("float kl_q = output_kl_q[0];"));
assert!(s.contains(
"c * 0.5 * q_j * (logtwo - metal::fast::log(1 + metal::fast::exp(logp_j - logq_j)) - kl_q)"
));
}
#[test]
fn validate_inputs_rejects_shape_mismatch() {
let a = Array::ones::<f32>(&[2, 4]).unwrap();
let b = Array::ones::<f32>(&[2, 8]).unwrap();
let err = validate_inputs(
&a,
&b,
"kl_div_loss: logits_q rank",
"kl_div_loss: logits_p rank",
"kl_div_loss: logits_q vs logits_p shape",
"kl_div_loss: logits last dimension",
"kl_div_loss: shape dim",
"kl_div_loss: logits dtype",
)
.unwrap_err();
match err {
Error::ShapePairMismatch(p) => {
assert_eq!(p.expected(), &[2, 4]);
assert_eq!(p.actual(), &[2, 8]);
}
other => panic!("expected ShapePairMismatch, got: {other:?}"),
}
}
#[test]
fn validate_inputs_rejects_dtype_mismatch() {
let a = Array::ones::<f32>(&[2, 4]).unwrap();
let b = Array::ones::<half::f16>(&[2, 4]).unwrap();
let err = validate_inputs(
&a,
&b,
"kl_div_loss: logits_q rank",
"kl_div_loss: logits_p rank",
"kl_div_loss: logits_q vs logits_p shape",
"kl_div_loss: logits last dimension",
"kl_div_loss: shape dim",
"kl_div_loss: logits dtype",
)
.unwrap_err();
match err {
Error::DtypeMismatch(p) => {
assert_eq!(p.expected(), Dtype::F32);
assert_eq!(p.got(), Dtype::F16);
}
other => panic!("expected DtypeMismatch, got: {other:?}"),
}
}
#[test]
fn kl_div_loss_rejects_shape_mismatch() {
let a = Array::ones::<f32>(&[2, 4]).unwrap();
let b = Array::ones::<f32>(&[2, 8]).unwrap();
let err = kl_div_loss(&a, &b).unwrap_err();
match err {
Error::ShapePairMismatch(p) => {
assert!(p.context().contains("kl_div_loss"), "got: {p:?}");
}
other => panic!("expected ShapePairMismatch, got: {other:?}"),
}
}
#[test]
fn js_div_loss_rejects_shape_mismatch() {
let a = Array::ones::<f32>(&[2, 4]).unwrap();
let b = Array::ones::<f32>(&[2, 8]).unwrap();
let err = js_div_loss(&a, &b).unwrap_err();
match err {
Error::ShapePairMismatch(p) => {
assert!(p.context().contains("js_div_loss"), "got: {p:?}");
}
other => panic!("expected ShapePairMismatch, got: {other:?}"),
}
}
#[test]
fn n_outs_of_computes_total_over_last_dim() {
let a = Array::ones::<f32>(&[3, 5, 7]).unwrap();
assert_eq!(n_outs_of(&a).unwrap(), 15);
}
#[test]
fn n_outs_of_rejects_rank_0() {
let a = Array::full::<f32>(&[0i32; 0], 1.0).unwrap();
let err = n_outs_of(&a).unwrap_err();
match err {
Error::RankMismatch(p) => {
assert!(p.context().contains("rank"));
assert_eq!(p.actual(), 0);
}
other => panic!("expected RankMismatch, got: {other:?}"),
}
}
#[test]
fn n_outs_of_rejects_zero_last_dim() {
let a = Array::ones::<f32>(&[2, 0]).unwrap();
let err = n_outs_of(&a).unwrap_err();
match err {
Error::OutOfRange(p) => {
assert_eq!(
p.context(),
"mlxrs::lm::tuner::losses: logits last dimension"
);
assert_eq!(p.requirement(), "must be > 0");
assert_eq!(p.value(), "0");
}
other => panic!("expected OutOfRange for zero last dim, got: {other:?}"),
}
}
#[test]
fn vocab_of_returns_last_dim_as_i32() {
let a = Array::ones::<f32>(&[2, 4, 128]).unwrap();
assert_eq!(vocab_of(&a).unwrap(), 128);
}
#[test]
fn leading_shape_strips_last_axis() {
let a = Array::ones::<f32>(&[3, 5, 7]).unwrap();
assert_eq!(leading_shape_i32(&a).unwrap(), vec![3i32, 5]);
}
#[test]
fn full_shape_preserves_all_dims() {
let a = Array::ones::<f32>(&[3, 5, 7]).unwrap();
assert_eq!(full_shape_i32(&a).unwrap(), vec![3i32, 5, 7]);
}
/// `template_for` yields exactly two template arguments: the dtype `T`, then the row
/// length `V`.
#[test]
fn template_for_emits_t_and_v_in_canonical_order() {
    let template = template_for(Dtype::F32, 128);
    assert_eq!(template.len(), 2);

    // Dtype parameter first...
    assert_eq!(template[0].0, "T");
    assert_eq!(template[0].1, KernelTemplateArg::Dtype(Dtype::F32));

    // ...then the per-row logits length.
    assert_eq!(template[1].0, "V");
    assert_eq!(template[1].1, KernelTemplateArg::Int(128));
}
/// Fetching the KL-forward kernel twice succeeds both times and reports the same arity.
/// Whether the second fetch hit the cache is not directly observable here.
#[test]
fn with_kernel_lazily_builds_then_reuses_on_second_call() {
    // First access builds the kernel.
    let first = with_kernel(&KL_FORWARD, build_kl_forward_kernel, |kernel| {
        Ok(kernel.output_arity())
    });
    let first_arity =
        first.expect("kl_forward kernel construction should not need a Metal device");
    assert_eq!(first_arity, 1, "kl_forward declares a single `out`");

    // Second access must also succeed and agree with the first.
    let second = with_kernel(&KL_FORWARD, build_kl_forward_kernel, |kernel| {
        Ok(kernel.output_arity())
    });
    let second_arity = second.expect("cached kl_forward kernel re-fetch should succeed");
    assert_eq!(second_arity, first_arity);
}
/// All four loss kernels (KL forward/backward, JS forward/backward) build and expose the
/// expected output arity and output names.
#[test]
fn with_kernel_builds_all_four_kernels() {
    // KL forward was previously missing even though the test name promises four kernels.
    let kl_fwd = with_kernel(&KL_FORWARD, build_kl_forward_kernel, |kernel| {
        Ok((kernel.output_arity(), kernel.output_names_slice().to_vec()))
    })
    .expect("kl_forward construction");
    assert_eq!(kl_fwd.0, 1);
    assert_eq!(kl_fwd.1, vec!["out".to_string()]);

    let kl_bwd = with_kernel(&KL_BACKWARD, build_kl_backward_kernel, |kernel| {
        Ok((kernel.output_arity(), kernel.output_names_slice().to_vec()))
    })
    .expect("kl_backward construction");
    assert_eq!(kl_bwd.0, 1);
    assert_eq!(kl_bwd.1, vec!["out".to_string()]);

    let js_fwd = with_kernel(&JS_FORWARD, build_js_forward_kernel, |kernel| {
        Ok((kernel.output_arity(), kernel.output_names_slice().to_vec()))
    })
    .expect("js_forward construction");
    assert_eq!(js_fwd.0, 2);
    assert_eq!(js_fwd.1, vec!["out".to_string(), "out_kl_q".to_string()]);

    let js_bwd = with_kernel(&JS_BACKWARD, build_js_backward_kernel, |kernel| {
        Ok((kernel.output_arity(), kernel.output_names_slice().to_vec()))
    })
    .expect("js_backward construction");
    assert_eq!(js_bwd.0, 1);
    assert_eq!(js_bwd.1, vec!["out_q".to_string()]);
}
/// A 0-d (empty-shape) array has no trailing axis, so `vocab_of` reports a rank mismatch.
#[test]
fn vocab_of_rejects_rank_0() {
    let no_dims: [i32; 0] = [];
    let scalar = Array::full::<f32>(&no_dims, 1.0).unwrap();
    match vocab_of(&scalar).unwrap_err() {
        Error::RankMismatch(p) => {
            assert!(p.context().contains("rank"), "got: {p:?}");
            assert_eq!(p.actual(), 0);
        }
        other => panic!("expected RankMismatch, got: {other:?}"),
    }
}
/// A 0-d (empty-shape) array has no axis to strip, so `leading_shape_i32` reports a rank
/// mismatch.
#[test]
fn leading_shape_i32_rejects_rank_0() {
    let no_dims: [i32; 0] = [];
    let scalar = Array::full::<f32>(&no_dims, 1.0).unwrap();
    match leading_shape_i32(&scalar).unwrap_err() {
        Error::RankMismatch(p) => {
            assert!(p.context().contains("rank"), "got: {p:?}");
            assert_eq!(p.actual(), 0);
        }
        other => panic!("expected RankMismatch, got: {other:?}"),
    }
}
/// `logits_q` is rank 2 but `logits_p` is rank 1, so validation must fail on the `p`
/// rank check and echo the offending shape.
#[test]
fn validate_inputs_rejects_rank_1_logits_p() {
    let logits_q = Array::ones::<f32>(&[1, 4]).unwrap();
    let logits_p = Array::ones::<f32>(&[4]).unwrap();
    let outcome = validate_inputs(
        &logits_q,
        &logits_p,
        "ctx_q rank",
        "ctx_p rank",
        "ctx_pair",
        "ctx_last",
        "ctx_dim",
        "ctx_dtype",
    );
    match outcome.unwrap_err() {
        // Named `payload` so it no longer shadows the `p` logits array.
        Error::RankMismatch(payload) => {
            assert_eq!(payload.context(), "ctx_p rank");
            assert_eq!(payload.actual(), 1);
            assert_eq!(payload.actual_shape(), &[4]);
        }
        other => panic!("expected RankMismatch on logits_p, got: {other:?}"),
    }
}
/// Empty trailing axes on both logits are rejected under the caller-supplied last-dim
/// context.
#[test]
fn validate_inputs_rejects_zero_last_dim() {
    let logits_q = Array::ones::<f32>(&[1, 0]).unwrap();
    let logits_p = Array::ones::<f32>(&[1, 0]).unwrap();
    let outcome = validate_inputs(
        &logits_q,
        &logits_p,
        "ctx_q rank",
        "ctx_p rank",
        "ctx_pair",
        "ctx_last",
        "ctx_dim",
        "ctx_dtype",
    );
    match outcome.unwrap_err() {
        Error::OutOfRange(payload) => {
            assert_eq!(payload.context(), "ctx_last");
            assert_eq!(payload.requirement(), "must be > 0");
            assert_eq!(payload.value(), "0");
        }
        other => panic!("expected OutOfRange for zero last dim, got: {other:?}"),
    }
}
/// Both 16-bit float dtypes (`f16` and `bf16`) pass `validate_inputs`.
#[test]
fn validate_inputs_accepts_f16_and_bf16() {
    let f16_q = Array::ones::<half::f16>(&[1, 4]).unwrap();
    let f16_p = Array::ones::<half::f16>(&[1, 4]).unwrap();
    let f16_outcome = validate_inputs(
        &f16_q,
        &f16_p,
        "ctx_q",
        "ctx_p",
        "ctx_pair",
        "ctx_last",
        "ctx_dim",
        "ctx_dtype",
    );
    f16_outcome.expect("f16 logits should validate");

    let bf16_q = Array::ones::<half::bf16>(&[1, 4]).unwrap();
    let bf16_p = Array::ones::<half::bf16>(&[1, 4]).unwrap();
    let bf16_outcome = validate_inputs(
        &bf16_q,
        &bf16_p,
        "ctx_q",
        "ctx_p",
        "ctx_pair",
        "ctx_last",
        "ctx_dim",
        "ctx_dtype",
    );
    bf16_outcome.expect("bf16 logits should validate");
}
/// Integer logits fall outside the supported float dtypes and are rejected with the full
/// supported list.
#[test]
fn validate_inputs_rejects_unsupported_dtype() {
    let logits_q = Array::ones::<i32>(&[1, 4]).unwrap();
    let logits_p = Array::ones::<i32>(&[1, 4]).unwrap();
    let outcome = validate_inputs(
        &logits_q,
        &logits_p,
        "ctx_q",
        "ctx_p",
        "ctx_pair",
        "ctx_last",
        "ctx_dim",
        "ctx_dtype: cast",
    );
    match outcome.unwrap_err() {
        Error::UnsupportedDtype(payload) => {
            assert_eq!(payload.context(), "ctx_dtype: cast");
            assert_eq!(payload.dtype(), Dtype::I32);
            assert_eq!(payload.supported(), &[Dtype::F32, Dtype::F16, Dtype::BF16]);
        }
        other => panic!("expected UnsupportedDtype, got: {other:?}"),
    }
}
/// Flat rank-1 logits are refused by `kl_div_loss` with a rank error naming the function.
#[test]
fn kl_div_loss_rejects_rank_1_input() {
    let flat_q = Array::ones::<f32>(&[4]).unwrap();
    let flat_p = Array::ones::<f32>(&[4]).unwrap();
    let outcome = kl_div_loss(&flat_q, &flat_p);
    match outcome.unwrap_err() {
        Error::RankMismatch(p) => {
            assert!(p.context().contains("kl_div_loss"), "got: {p:?}");
            assert_eq!(p.actual(), 1);
        }
        other => panic!("expected RankMismatch, got: {other:?}"),
    }
}
/// Integer logits are refused by `js_div_loss` with a dtype error naming the function.
#[test]
fn js_div_loss_rejects_unsupported_dtype() {
    let int_q = Array::ones::<i32>(&[1, 4]).unwrap();
    let int_p = Array::ones::<i32>(&[1, 4]).unwrap();
    let outcome = js_div_loss(&int_q, &int_p);
    match outcome.unwrap_err() {
        Error::UnsupportedDtype(p) => {
            assert!(p.context().contains("js_div_loss"), "got: {p:?}");
            assert_eq!(p.dtype(), Dtype::I32);
        }
        other => panic!("expected UnsupportedDtype, got: {other:?}"),
    }
}
/// ln(3). Logits `[0, LN3]` have softmax `[1/4, 3/4]`, which gives the GPU tests below
/// hand-checkable closed forms.
#[cfg(target_os = "macos")]
const LN3: f32 = 1.098_612_3_f32;
/// softmax([0, ln 3]) = [1/4, 3/4] and softmax([0, 0]) = [1/2, 1/2], so
/// KL(P || Q) = 1/2·ln(2) + 1/2·ln(2/3) = 1/2·(2·ln 2 − ln 3) ≈ 0.1438.
#[cfg(target_os = "macos")]
#[test]
#[ignore = "requires a Metal-capable GPU"]
fn kl_div_loss_matches_closed_form() {
    let expected = 0.5 * (2.0 * std::f32::consts::LN_2 - LN3);
    let q = Array::from_slice::<f32>(&[0.0, LN3], &[1, 2]).unwrap();
    let p = Array::from_slice::<f32>(&[0.0, 0.0], &[1, 2]).unwrap();

    let mut loss = kl_div_loss(&q, &p).unwrap();
    assert_eq!(loss.shape(), vec![1]);

    let values: Vec<f32> = loss.to_vec().unwrap();
    let got = values[0];
    assert!(
        (got - expected).abs() < 1e-4,
        "KL loss = {}, expected ≈ {expected}",
        got
    );
}
/// The divergence of a distribution from itself is zero.
#[cfg(target_os = "macos")]
#[test]
#[ignore = "requires a Metal-capable GPU"]
fn kl_div_loss_identity_is_zero() {
    let logits = Array::from_slice::<f32>(&[0.3, -1.2, 2.5, 0.0], &[1, 4]).unwrap();
    let same = logits.try_clone().unwrap();
    let mut loss = kl_div_loss(&logits, &same).unwrap();
    let values: Vec<f32> = loss.to_vec().unwrap();
    let got = values[0];
    assert!(got.abs() < 1e-5, "KL(P||P) = {}, expected ≈ 0", got);
}
/// Row 0 reuses the closed-form case and row 1 has identical q/p logits (KL = 0).
/// Each row must get its own result, with no leakage between rows.
#[cfg(target_os = "macos")]
#[test]
#[ignore = "requires a Metal-capable GPU"]
fn kl_div_loss_batched_rows_are_independent() {
    let row0 = 0.5 * (2.0 * std::f32::consts::LN_2 - LN3);
    let q = Array::from_slice::<f32>(&[0.0, LN3, 0.7, 0.7], &[2, 2]).unwrap();
    let p = Array::from_slice::<f32>(&[0.0, 0.0, 0.7, 0.7], &[2, 2]).unwrap();

    let mut loss = kl_div_loss(&q, &p).unwrap();
    assert_eq!(loss.shape(), vec![2]);

    let per_row: Vec<f32> = loss.to_vec().unwrap();
    assert!(
        (per_row[0] - row0).abs() < 1e-4,
        "row0 = {}, expected {row0}",
        per_row[0]
    );
    assert!(per_row[1].abs() < 1e-5, "row1 = {}, expected ≈ 0", per_row[1]);
}
/// The gradient w.r.t. `logits_q` is softmax(q) − softmax(p) = [1/4 − 1/2, 3/4 − 1/2]
/// = [−0.25, 0.25]. The gradient w.r.t. `logits_p` is asserted to be zero.
#[cfg(target_os = "macos")]
#[test]
#[ignore = "requires a Metal-capable GPU"]
fn kl_div_loss_grad_equals_q_minus_p() {
    let q = Array::from_slice::<f32>(&[0.0, LN3], &[1, 2]).unwrap();
    let p = Array::from_slice::<f32>(&[0.0, 0.0], &[1, 2]).unwrap();

    // Differentiate the summed loss w.r.t. both arguments.
    let grad_fn = crate::transforms::grad(
        |xs: &[Array]| Ok(vec![kl_div_loss(&xs[0], &xs[1])?.sum(false)?]),
        &[0, 1],
    )
    .unwrap();
    let mut grads = grad_fn(&[q, p]).unwrap();
    assert_eq!(grads.len(), 2);

    let dq: Vec<f32> = grads[0].to_vec().unwrap();
    let dp: Vec<f32> = grads[1].to_vec().unwrap();

    let (dq0, dq1) = (dq[0], dq[1]);
    assert!((dq0 - (-0.25)).abs() < 1e-4, "dq[0] = {}, expected -0.25", dq0);
    assert!((dq1 - 0.25).abs() < 1e-4, "dq[1] = {}, expected 0.25", dq1);
    assert!(
        dp[0].abs() < 1e-6 && dp[1].abs() < 1e-6,
        "dp = {dp:?}, expected zeros"
    );
}
/// With P = [1/2, 1/2], Q = [1/4, 3/4] and M = (P + Q)/2 = [3/8, 5/8]:
/// JS = 1/2·KL(P||M) + 1/2·KL(Q||M) ≈ 1/2·0.032269 + 1/2·0.035375 ≈ 0.0338221.
#[cfg(target_os = "macos")]
#[test]
#[ignore = "requires a Metal-capable GPU"]
fn js_div_loss_matches_closed_form() {
    let expected = 0.033_822_08_f32;
    let q = Array::from_slice::<f32>(&[0.0, LN3], &[1, 2]).unwrap();
    let p = Array::from_slice::<f32>(&[0.0, 0.0], &[1, 2]).unwrap();

    let mut loss = js_div_loss(&q, &p).unwrap();
    assert_eq!(loss.shape(), vec![1]);

    let values: Vec<f32> = loss.to_vec().unwrap();
    let got = values[0];
    assert!(
        (got - expected).abs() < 1e-4,
        "JS loss = {}, expected ≈ {expected}",
        got
    );
}
/// The Jensen–Shannon divergence of a distribution with itself is zero.
#[cfg(target_os = "macos")]
#[test]
#[ignore = "requires a Metal-capable GPU"]
fn js_div_loss_identity_is_zero() {
    let logits = Array::from_slice::<f32>(&[0.3, -1.2, 2.5, 0.0], &[1, 4]).unwrap();
    let same = logits.try_clone().unwrap();
    let mut loss = js_div_loss(&logits, &same).unwrap();
    let values: Vec<f32> = loss.to_vec().unwrap();
    let got = values[0];
    assert!(got.abs() < 1e-5, "JS(P,P) = {}, expected ≈ 0", got);
}
/// The backward pass runs, and the `logits_q` gradient sums to ~0 across the row (softmax
/// shift invariance) while remaining non-trivial. The `logits_p` gradient is asserted to
/// be zero.
#[cfg(target_os = "macos")]
#[test]
#[ignore = "requires a Metal-capable GPU"]
fn js_div_loss_grad_runs_and_sums_to_zero() {
    let q = Array::from_slice::<f32>(&[0.0, LN3], &[1, 2]).unwrap();
    let p = Array::from_slice::<f32>(&[0.0, 0.0], &[1, 2]).unwrap();

    let grad_fn = crate::transforms::grad(
        |xs: &[Array]| Ok(vec![js_div_loss(&xs[0], &xs[1])?.sum(false)?]),
        &[0, 1],
    )
    .unwrap();
    let mut grads = grad_fn(&[q, p]).unwrap();
    assert_eq!(grads.len(), 2);

    let dq: Vec<f32> = grads[0].to_vec().unwrap();
    let dp: Vec<f32> = grads[1].to_vec().unwrap();
    assert_eq!(dq.len(), 2);

    let dq_sum = dq[0] + dq[1];
    assert!(dq_sum.abs() < 1e-4, "Σ dq = {}, expected ≈ 0", dq_sum);

    let non_trivial = dq[0].abs() > 1e-4 || dq[1].abs() > 1e-4;
    assert!(non_trivial, "dq unexpectedly ≈ 0: {dq:?}");

    let dp_is_zero = dp[0].abs() < 1e-6 && dp[1].abs() < 1e-6;
    assert!(dp_is_zero, "dp = {dp:?}, expected zeros");
}
}