use crate::train::prelude::*;
impl<B: AutodiffBackend> TrainStep<SequenceBatch<B>, LanguageModelTrainItem<B>> for BDH<B> {
fn step(&self, batch: SequenceBatch<B>) -> TrainOutput<LanguageModelTrainItem<B>> {
let logits = if fast_train_enabled() {
self.forward_fast(batch.inputs)
} else {
self.forward(batch.inputs)
};
let loss = language_model_loss::<B>(logits, batch.targets);
let grads = loss.backward();
TrainOutput::new(self, grads, LanguageModelTrainItem::new(loss))
}
}
impl<B: BackendTrait> ValidStep<SequenceBatch<B>, LanguageModelOutput<B>> for BDH<B> {
fn step(&self, batch: SequenceBatch<B>) -> LanguageModelOutput<B> {
let logits = if fast_train_enabled() {
self.forward_fast(batch.inputs)
} else {
self.forward(batch.inputs)
};
let loss = language_model_loss::<B>(logits, batch.targets);
LanguageModelOutput::new(loss)
}
}