// finalfrontier/loss.rs
1use ndarray::ArrayView1;
2
3use crate::util;
4use crate::vec_simd::dot;
5
/// Absolute activations to round in logistic regression.
///
/// Since the logistic function is asymptotic, there is always a (small)
/// gradient for larger activations. As a result, optimization of logistic
/// regression will not converge without e.g. regularization. In the
/// training of embeddings, this has the result of ever-increasing weights
/// (amplified by jointly optimizing the two vectors).
///
/// A simpler solution than regularization is to round the output of the
/// logistic function to 0 (negative activation) or 1 (positive activation)
/// for large activations, to kill the gradient.
///
/// This constant controls at what absolute activation the logistic
/// function should round.
const LOGISTIC_ROUND_ACTIVATION: f32 = 10.0;
21
22/// Return the loss and gradient of the co-occurence classification.
23///
24/// This function returns the negative log likelihood and gradient of
25/// a training instance using the probability function *P(1|x) =
26/// σ(u·v)*. `u` and `v` are word embeddings and `label` is the
27/// target label, where a label of `1` means that the words co-occur
28/// and a label of `0` that they do not.
29///
30/// This model is very similar to logistic regression, except that we
31/// optimize both u and v.
32///
33/// The loss is as follows (y is used as the label):
34///
35/// log(P(y|x)) =
36/// y log(P(1|x)) + (1-y) log(P(0|x)) =
37/// y log(P(1|x)) + (1-y) log(1 - P(1|x)) =
38/// y log(σ(u·v)) + (1-y) log(1 - σ(u·v)) =
39/// y log(σ(u·v)) + (1-y) log(σ(-u·v))
40///
41/// We can simplify the first term:
42///
43/// y log(σ(u·v)) =
44/// y log(1/(1+e^{-u·v})) =
45/// -y log(1+e^{-u·v})
46///
47/// Then we find the derivative with respect to v_1:
48///
49/// ∂/∂v_1 -y log(1+e^{-u·v}) =
50/// -y σ(u·v) ∂/∂v_1(1+e^{-u·v}) =
51/// -y σ(u·v) e^{-u·v} -u_1 =
52/// y σ(-u·v) u_1 =
53/// y (1 - σ(u·v)) u_1 =
54/// (y - yσ(u·v)) u_1
55///
56/// Iff y = 1, then:
57///
58/// 1 - σ(u·v)
59///
60/// For the second term above, we also find the derivative:
61///
62/// ∂/∂v_1 -(1 - y) log(1+e^{u·v}) =
63/// -(1 - y) σ(-u·v) ∂/∂v_1(1+e^{u·v}) =
64/// -(1 - y) σ(-u·v) e^{u·v} ∂/∂v_1 u·v=
65/// -(1 - y) σ(-u·v) e^{u·v} u_1 =
66/// -(1 - y) σ(u·v) u_1 =
67/// (-σ(u·v) + yσ(u·v)) u_1
68///
69/// When y = 0 then:
70///
71/// -σ(u·v)u_1
72///
73/// Combining both, the partial derivative of v_1 is: y - σ(u·v)u_1
74///
75/// We return y - σ(u·v) as the gradient, so that the caller can compute
76/// the gradient for all components of u and v.
77pub fn log_logistic_loss(u: ArrayView1<f32>, v: ArrayView1<f32>, label: bool) -> (f32, f32) {
78 let dp = dot(u, v);
79 let lf = logistic_function(dp);
80 let grad = (label as usize) as f32 - lf;
81 let loss = if label {
82 -util::safe_ln(lf)
83 } else {
84 -util::safe_ln(1.0 - lf)
85 };
86
87 (loss, grad)
88}
89
90/// Compute the logistic function.
91///
92/// **σ(a) = 1 / (1 + e^{-a})**
93fn logistic_function(a: f32) -> f32 {
94 if a > LOGISTIC_ROUND_ACTIVATION {
95 1.0
96 } else if a < -LOGISTIC_ROUND_ACTIVATION {
97 0.0
98 } else {
99 1.0 / (1.0 + (-a).exp())
100 }
101}
102
#[cfg(test)]
mod tests {
    use ndarray::Array1;

    use crate::util::{all_close, close};

    use super::{log_logistic_loss, logistic_function};

    #[test]
    fn logistic_function_test() {
        // Activations beyond ±LOGISTIC_ROUND_ACTIVATION (±11 here) must
        // round to exactly 0 and 1; the rest follow the sigmoid curve.
        let inputs: [f32; 13] = [
            -11.0, -5.0, -4.0, -3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 11.0,
        ];
        let expected: &[f32] = &[
            0.0, 0.00669, 0.01799, 0.04743, 0.11920, 0.26894, 0.5, 0.73106, 0.88080, 0.95257,
            0.982014, 0.99331, 1.0,
        ];
        let actual: Vec<_> = inputs.iter().map(|&a| logistic_function(a)).collect();
        assert!(all_close(expected, actual.as_slice(), 1e-5));
    }

    #[test]
    fn log_logistic_loss_test() {
        let ones = Array1::from_vec(vec![1., 1., 1., 0., 0., 0.]);
        let orthogonal = Array1::from_vec(vec![0., 0., 0., 1., 1., 1.]);
        let opposite = Array1::from_vec(vec![-1., -1., -1., 0., 0., 0.]);

        // Orthogonal vectors: σ(0) = 0.5, maximal uncertainty for either label.
        let (loss, gradient) = log_logistic_loss(ones.view(), orthogonal.view(), true);
        assert!(close(loss, 0.69312, 1e-5));
        assert!(close(gradient, 0.5, 1e-5));

        let (loss, gradient) = log_logistic_loss(ones.view(), orthogonal.view(), false);
        assert!(close(loss, 0.69312, 1e-5));
        assert!(close(gradient, -0.5, 1e-5));

        // Identical vectors, positive label: confident and correct, small loss.
        let (loss, gradient) = log_logistic_loss(ones.view(), ones.view(), true);
        assert!(close(loss, 0.04858, 1e-5));
        assert!(close(gradient, 0.04742, 1e-5));

        // Opposite vectors, negative label: confident and correct, small loss.
        let (loss, gradient) = log_logistic_loss(ones.view(), opposite.view(), false);
        assert!(close(loss, 0.04858, 1e-5));
        assert!(close(gradient, -0.04743, 1e-5));

        // Opposite vectors, positive label: confidently wrong, large loss.
        let (loss, gradient) = log_logistic_loss(ones.view(), opposite.view(), true);
        assert!(close(loss, 3.04838, 1e-5));
        assert!(close(gradient, 0.95257, 1e-5));
    }
}
153}