use ndarray::{Array1, ArrayView1, ArrayViewMut1};
use crate::hogwild::Hogwild;
use crate::idx::WordIdx;
use crate::loss::log_logistic_loss;
use crate::train_model::{NegativeSamples, TrainIterFrom, TrainModel, Trainer};
use crate::vec_simd::scaled_add;
/// Stochastic gradient descent training state.
///
/// Bundles the model with running statistics. The counters are wrapped in
/// `Hogwild`, which presumably allows lock-free concurrent updates from
/// multiple training threads (Hogwild-style SGD) — TODO confirm against
/// `crate::hogwild`.
#[derive(Clone)]
pub struct SGD<T> {
    // Running sum of the loss over all processed examples.
    loss: Hogwild<f32>,
    // The model being trained in place.
    model: TrainModel<T>,
    // Number of (focus, context) examples processed so far.
    n_examples: Hogwild<usize>,
    // Number of focus tokens processed so far.
    n_tokens_processed: Hogwild<usize>,
    // Negative-sampling SGD implementation used for each step.
    sgd_impl: NegativeSamplingSGD,
}
impl<T> SGD<T>
where
    T: Trainer,
{
    /// Consume the SGD state, returning the trained model.
    pub fn into_model(self) -> TrainModel<T> {
        self.model
    }

    /// Construct SGD state for `model`.
    ///
    /// The number of negative samples drawn per positive example is taken
    /// from the model's configuration; all counters start at zero.
    pub fn new(model: TrainModel<T>) -> Self {
        let sgd_impl = NegativeSamplingSGD::new(model.config().negative_samples as usize);

        SGD {
            loss: Hogwild::default(),
            model,
            n_examples: Hogwild::default(),
            n_tokens_processed: Hogwild::default(),
            sgd_impl,
        }
    }

    /// Get a reference to the model being trained.
    pub fn model(&self) -> &TrainModel<T> {
        &self.model
    }

    /// Number of focus tokens processed so far.
    pub fn n_tokens_processed(&self) -> usize {
        *self.n_tokens_processed
    }

    /// Mean loss per processed training example.
    ///
    /// Returns `0.0` when no examples have been processed yet. The naive
    /// division would evaluate `0.0 / 0.0 == NaN` in that case, which
    /// poisons any downstream progress reporting.
    pub fn train_loss(&self) -> f32 {
        let n_examples = *self.n_examples;
        if n_examples == 0 {
            0.0
        } else {
            *self.loss / n_examples as f32
        }
    }

    /// Update the model with all examples extracted from `sentence`.
    ///
    /// For every focus item the trainer yields, the mean input embedding of
    /// the focus is trained against each of its contexts with learning rate
    /// `lr`. The loss, example count, and token count are accumulated.
    pub fn update_sentence<'b, S>(&mut self, sentence: &S, lr: f32)
    where
        S: ?Sized,
        T: TrainIterFrom<'b, S> + Trainer + NegativeSamples,
        for<'a> &'a T::Focus: IntoIterator<Item = u64>,
        T::Focus: WordIdx,
    {
        for (focus, contexts) in self.model.trainer().train_iter_from(sentence) {
            // A focus may consist of multiple indices (e.g. subword units);
            // train on the mean of their input embeddings.
            let input_embed = self.model.mean_input_embedding(&focus);

            for context in contexts {
                *self.loss += self.sgd_impl.sgd_step(
                    &mut self.model,
                    (&focus).into_iter(),
                    input_embed.view(),
                    context,
                    lr,
                );
                *self.n_examples += 1;
            }

            // One token per focus item, regardless of how many contexts it had.
            *self.n_tokens_processed += 1;
        }
    }
}
/// Log-logistic-loss SGD with negative sampling.
///
/// Stateless apart from the configured number of negative samples drawn
/// for each positive (input, output) example.
#[derive(Clone)]
pub struct NegativeSamplingSGD {
    // Negative samples drawn per positive example.
    negative_samples: usize,
}
impl NegativeSamplingSGD {
    /// Create a negative-sampling SGD that draws `negative_samples`
    /// negative examples per positive example.
    pub fn new(negative_samples: usize) -> Self {
        NegativeSamplingSGD { negative_samples }
    }

    /// Perform a single SGD step: one positive update for `output`,
    /// `negative_samples` negative updates, then apply the accumulated
    /// input gradient to every embedding in `input`. Returns the summed
    /// loss of all updates.
    pub fn sgd_step<T>(
        &mut self,
        model: &mut TrainModel<T>,
        input: impl IntoIterator<Item = u64>,
        input_embed: ArrayView1<f32>,
        output: usize,
        lr: f32,
    ) -> f32
    where
        T: NegativeSamples,
    {
        // Gradient with respect to the (mean) input embedding is collected
        // here first, so it can be applied to each input index exactly once.
        let mut input_delta = Array1::zeros(input_embed.len());

        // Positive example: train towards the true output (label = true).
        let positive_loss = self.update_output(
            model,
            input_embed.view(),
            input_delta.view_mut(),
            output,
            true,
            lr,
        );

        // Negative examples: train away from sampled outputs (label = false).
        let negative_loss =
            self.negative_samples(model, input_embed, input_delta.view_mut(), output, lr);

        // Propagate the accumulated delta to every input embedding.
        for idx in input {
            let embed = model.input_embedding_mut(idx as usize);
            scaled_add(embed, input_delta.view(), 1.0);
        }

        positive_loss + negative_loss
    }

    /// Run the negative-sample updates for one positive example,
    /// returning their summed loss.
    fn negative_samples<T>(
        &mut self,
        model: &mut TrainModel<T>,
        input_embed: ArrayView1<f32>,
        mut input_delta: ArrayViewMut1<f32>,
        output: usize,
        lr: f32,
    ) -> f32
    where
        T: NegativeSamples,
    {
        let mut total_loss = 0f32;

        for _ in 0..self.negative_samples {
            // Ask the trainer for a negative sample for this output.
            let negative = model.trainer().negative_sample(output);

            total_loss += self.update_output(
                model,
                input_embed.view(),
                input_delta.view_mut(),
                negative,
                false,
                lr,
            );
        }

        total_loss
    }

    /// Update the output embedding of `output` and accumulate the input
    /// gradient into `input_delta`; returns the example's loss.
    fn update_output<T>(
        &mut self,
        model: &mut TrainModel<T>,
        input_embed: ArrayView1<f32>,
        input_delta: ArrayViewMut1<f32>,
        output: usize,
        label: bool,
        lr: f32,
    ) -> f32 {
        // Loss plus the shared scalar part of the gradient for this pair.
        let (loss, part_gradient) =
            log_logistic_loss(input_embed.view(), model.output_embedding(output), label);

        let gradient_scale = lr * part_gradient;

        // Gradient w.r.t. the input embedding: accumulated, applied later
        // by the caller.
        scaled_add(input_delta, model.output_embedding(output), gradient_scale);

        // Gradient w.r.t. the output embedding: applied immediately.
        scaled_add(
            model.output_embedding_mut(output),
            input_embed.view(),
            gradient_scale,
        );

        loss
    }
}