// Module: stdlib/nn/loss/kl_div.tern
// Purpose: Kullback-Leibler Divergence
// Author: RFI-IRFOS
// Ref: https://ternlang.com
// Measures difference between two trit probability distributions.
// Forward trit KL divergence D(p || q) for a single pair of trit values.
// Returns:
//   tend   - p (the reference distribution) is undetermined, so no verdict is possible;
//   affirm - p and q agree (zero divergence);
//   reject - p and q differ, including when q is tend while p is determined.
fn forward_kl(p: trit, q: trit) -> trit {
    match p {
        tend => { return tend; }
        affirm => {
            if q == affirm { return affirm; }
            return reject;
        }
        reject => {
            if q == reject { return affirm; }
            return reject;
        }
    }
}
// Reverse trit KL divergence D(q || p): the forward divergence with the
// roles of the two distributions swapped, so q becomes the reference.
fn reverse_kl(p: trit, q: trit) -> trit {
    let swapped: trit = forward_kl(q, p);
    return swapped;
}
// Symmetrized (Jeffreys-style) trit KL: combines D(p || q) and D(q || p).
// This is not a Jensen-Shannon analog; JS compares each side against the
// mixture distribution rather than against each other.
// Returns tend when either direction is undetermined. Otherwise it returns
// the forward verdict. The result is exact, not approximated: a forward
// affirm implies p == q, so the reverse direction is affirm too.
fn symmetrized_kl(p: trit, q: trit) -> trit {
    let fwd: trit = forward_kl(p, q);
    let rev: trit = reverse_kl(p, q);
    if rev == tend { return tend; }
    if fwd == tend { return tend; }
    // fwd is now affirm or reject; pass it through unchanged.
    return fwd;
}
// Trit KL divergence D(p || q) between two trit distributions, aggregated
// element by element with forward_kl.
// Returns:
//   tend   - the distributions have different lengths or are empty, or no
//            element diverges but at least one element is undetermined;
//   reject - at least one aligned pair of elements diverges. KL is a sum of
//            non-negative terms, so a single non-zero term makes the total non-zero;
//   affirm - every aligned pair agrees (zero divergence).
// Previously this was a stub that ignored both inputs and always returned tend.
fn kl_divergence_trit(p_dist: trit[], q_dist: trit[]) -> trit {
    // Distributions over different supports cannot be compared.
    if len(p_dist) != len(q_dist) { return tend; }
    // No evidence either way; keeps the previous result for this edge case.
    if len(p_dist) == 0 { return tend; }

    // A divergent component dominates any undetermined component.
    for i in 0..len(p_dist) {
        if forward_kl(p_dist[i], q_dist[i]) == reject { return reject; }
    }
    // No divergence found; any undetermined component leaves the total unknown.
    for i in 0..len(p_dist) {
        if forward_kl(p_dist[i], q_dist[i]) == tend { return tend; }
    }
    return affirm;
}