use super::*;
pub struct LanguageModelOutput<B: BackendTrait> {
loss: Tensor<B, 1>,
}
impl<B: BackendTrait> LanguageModelOutput<B> {
pub fn new(loss: Tensor<B, 1>) -> Self {
Self { loss }
}
}
impl<B: BackendTrait> ItemLazy for LanguageModelOutput<B> {
type ItemSync = Self;
fn sync(self) -> Self::ItemSync {
self
}
}
impl<B: BackendTrait> Adaptor<LossInput<B>> for LanguageModelOutput<B> {
fn adapt(&self) -> LossInput<B> {
LossInput::new(self.loss.clone())
}
}
impl<B: BackendTrait> Adaptor<LossValue<B>> for LanguageModelOutput<B> {
fn adapt(&self) -> LossValue<B> {
LossValue::new(self.loss.clone())
}
}
#[derive(Clone)]
pub struct LossValue<B: BackendTrait> {
value: Tensor<B, 1>,
}
impl<B: BackendTrait> LossValue<B> {
pub fn new(value: Tensor<B, 1>) -> Self {
Self { value }
}
}
#[derive(Clone)]
pub struct LanguageModelTrainItem<B: AutodiffBackend> {
loss: Tensor<B, 1>,
}
impl<B: AutodiffBackend> LanguageModelTrainItem<B> {
pub fn new(loss: Tensor<B, 1>) -> Self {
Self {
loss: loss.detach(),
}
}
}
impl<B: AutodiffBackend> ItemLazy for LanguageModelTrainItem<B> {
type ItemSync = LanguageModelOutput<B::InnerBackend>;
fn sync(self) -> Self::ItemSync {
LanguageModelOutput::new(self.loss.detach().inner())
}
}
impl<B: BackendTrait> ScalarValue<B> for LossValue<B> {
fn value(&self) -> Tensor<B, 1> {
self.value.clone()
}
}