use crate::metric::{AccuracyInput, Adaptor, HammingScoreInput, LossInput};
use burn_core::tensor::backend::Backend;
use burn_core::tensor::{Int, Tensor};
#[derive(new)]
pub struct ClassificationOutput<B: Backend> {
pub loss: Tensor<B, 1>,
pub output: Tensor<B, 2>,
pub targets: Tensor<B, 1, Int>,
}
impl<B: Backend> Adaptor<AccuracyInput<B>> for ClassificationOutput<B> {
fn adapt(&self) -> AccuracyInput<B> {
AccuracyInput::new(self.output.clone(), self.targets.clone())
}
}
impl<B: Backend> Adaptor<LossInput<B>> for ClassificationOutput<B> {
fn adapt(&self) -> LossInput<B> {
LossInput::new(self.loss.clone())
}
}
#[derive(new)]
pub struct MultiLabelClassificationOutput<B: Backend> {
pub loss: Tensor<B, 1>,
pub output: Tensor<B, 2>,
pub targets: Tensor<B, 2, Int>,
}
impl<B: Backend> Adaptor<HammingScoreInput<B>> for MultiLabelClassificationOutput<B> {
fn adapt(&self) -> HammingScoreInput<B> {
HammingScoreInput::new(self.output.clone(), self.targets.clone())
}
}
impl<B: Backend> Adaptor<LossInput<B>> for MultiLabelClassificationOutput<B> {
fn adapt(&self) -> LossInput<B> {
LossInput::new(self.loss.clone())
}
}