use crate::TrainResult;
use scirs2_core::ndarray::{Array, ArrayView, Ix2};
use std::fmt::Debug;
use super::functions::Loss;
/// Parameters for a Dice loss.
///
/// NOTE(review): the `Loss` implementation for this type is not in this part of
/// the file; the field description below is inferred from its name — confirm.
#[derive(Debug, Clone)]
pub struct DiceLoss {
/// Smoothing constant — presumably added to the numerator and denominator of the
/// Dice ratio to avoid division by zero (TODO confirm against the `Loss` impl).
pub smooth: f64,
}
/// Parameters for a contrastive loss.
///
/// NOTE(review): the `Loss` implementation for this type is not in this part of
/// the file; the field description below is inferred from its name — confirm.
#[derive(Debug, Clone)]
pub struct ContrastiveLoss {
/// Margin — presumably the distance beyond which dissimilar pairs stop being
/// penalized (TODO confirm against the `Loss` impl).
pub margin: f64,
}
/// Parameters for a loss that scores how well rule values are satisfied.
///
/// NOTE(review): the `Loss` implementation for this type is not in this part of
/// the file; the field description below is inferred from its name — confirm.
#[derive(Debug, Clone)]
pub struct RuleSatisfactionLoss {
/// Temperature — presumably scales values before a soft (e.g. sigmoid/softmax)
/// transformation (TODO confirm against the `Loss` impl).
pub temperature: f64,
}
/// Binary cross-entropy loss computed from logits; a unit struct with no parameters.
///
/// NOTE(review): presumably applies a sigmoid to the predictions internally —
/// confirm against the `Loss` impl, which is not in this part of the file.
#[derive(Debug, Clone, Default)]
pub struct BCEWithLogitsLoss;
/// Weights used by `LogicalLoss` to combine its component losses.
#[derive(Debug, Clone)]
pub struct LossConfig {
/// Multiplier applied to the supervised loss in `LogicalLoss::compute_total`.
pub supervised_weight: f64,
/// Multiplier applied to each constraint loss in `LogicalLoss::compute_total`.
pub constraint_weight: f64,
/// Multiplier applied to each rule loss in `LogicalLoss::compute_total`.
pub rule_weight: f64,
/// Not read by `LogicalLoss::compute_total`.
/// NOTE(review): its consumer is not visible here — confirm where it is used.
pub temperature: f64,
}
/// Parameters for a Tversky loss.
///
/// NOTE(review): the `Loss` implementation for this type is not in this part of
/// the file; the field descriptions below are inferred from names — confirm.
#[derive(Debug, Clone)]
pub struct TverskyLoss {
/// Presumably the weight on false positives (TODO confirm).
pub alpha: f64,
/// Presumably the weight on false negatives (TODO confirm).
pub beta: f64,
/// Smoothing constant, presumably guarding against division by zero (TODO confirm).
pub smooth: f64,
}
/// Mean-squared-error loss; a unit struct with no parameters.
///
/// NOTE(review): the reduction (mean vs. sum) is defined by the `Loss` impl,
/// which is not in this part of the file — confirm there.
#[derive(Debug, Clone, Default)]
pub struct MseLoss;
/// Parameters for a triplet loss.
///
/// NOTE(review): the `Loss` implementation for this type is not in this part of
/// the file; the field description below is inferred from its name — confirm.
#[derive(Debug, Clone)]
pub struct TripletLoss {
/// Margin — presumably the minimum required gap between anchor–negative and
/// anchor–positive distances (TODO confirm against the `Loss` impl).
pub margin: f64,
}
/// Parameters for a focal loss.
///
/// NOTE(review): the `Loss` implementation for this type is not in this part of
/// the file; the field descriptions below are inferred from names — confirm.
#[derive(Debug, Clone)]
pub struct FocalLoss {
/// Presumably the class-balancing weight (TODO confirm).
pub alpha: f64,
/// Presumably the focusing exponent that down-weights easy examples (TODO confirm).
pub gamma: f64,
/// Small constant, presumably used to keep probabilities away from 0/1 before
/// taking logarithms (TODO confirm).
pub epsilon: f64,
}
/// Parameters for a hinge loss.
///
/// NOTE(review): the `Loss` implementation for this type is not in this part of
/// the file; the field description below is inferred from its name — confirm.
#[derive(Debug, Clone)]
pub struct HingeLoss {
/// Margin — presumably the score margin below which a penalty is incurred
/// (TODO confirm against the `Loss` impl).
pub margin: f64,
}
/// Parameters for a polynomial ("Poly") loss.
///
/// NOTE(review): the `Loss` implementation for this type is not in this part of
/// the file; the field descriptions below are inferred from names — confirm.
#[derive(Debug, Clone)]
pub struct PolyLoss {
    /// Small constant, presumably for numerical stability (TODO confirm).
    pub epsilon: f64,
    /// Coefficient of the polynomial term (TODO confirm its exact role).
    pub poly_coeff: f64,
}

impl PolyLoss {
    /// Builds a `PolyLoss` with the given `poly_coeff` and `epsilon` set to `1e-10`.
    ///
    /// Both fields are public, so callers may override `epsilon` afterwards.
    pub fn new(poly_coeff: f64) -> Self {
        const DEFAULT_EPSILON: f64 = 1e-10;
        PolyLoss {
            poly_coeff,
            epsilon: DEFAULT_EPSILON,
        }
    }
}
/// Weighted combination of a supervised loss with rule and constraint losses.
///
/// `compute_total` evaluates `supervised_loss` on (predictions, targets) and
/// evaluates each rule loss against an all-ones target and each constraint loss
/// against an all-zeros target. It then sums everything, weighted by `config`.
#[derive(Debug)]
pub struct LogicalLoss {
/// Per-component weights.
pub config: LossConfig,
/// Loss applied to (predictions, targets).
pub supervised_loss: Box<dyn Loss>,
/// Paired positionally with the `rule_values` passed to `compute_total`.
pub rule_losses: Vec<Box<dyn Loss>>,
/// Paired positionally with the `constraint_values` passed to `compute_total`.
pub constraint_losses: Vec<Box<dyn Loss>>,
}
impl LogicalLoss {
    /// Creates a combined loss from its weights and component losses.
    ///
    /// `rule_losses` and `constraint_losses` are paired positionally with the
    /// `rule_values` / `constraint_values` slices later passed to
    /// [`compute_total`](Self::compute_total).
    pub fn new(
        config: LossConfig,
        supervised_loss: Box<dyn Loss>,
        rule_losses: Vec<Box<dyn Loss>>,
        constraint_losses: Vec<Box<dyn Loss>>,
    ) -> Self {
        Self {
            config,
            supervised_loss,
            rule_losses,
            constraint_losses,
        }
    }

    /// Computes the weighted sum of the supervised, rule and constraint losses.
    ///
    /// ```text
    /// supervised_weight * supervised_loss(predictions, targets)
    ///   + Σ_i rule_weight       * rule_losses[i](rule_values[i], ONES)
    ///   + Σ_j constraint_weight * constraint_losses[j](constraint_values[j], ZEROS)
    /// ```
    ///
    /// `ONES` / `ZEROS` are arrays filled with `1.0` / `0.0` that have the same
    /// shape as the value array they are compared with. Values and losses are
    /// paired positionally. If a value slice and its loss list differ in length,
    /// the surplus entries on the longer side are ignored.
    ///
    /// # Errors
    ///
    /// Propagates the first error returned by any component loss. The supervised
    /// loss is evaluated first, then the rules, then the constraints.
    pub fn compute_total(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
        rule_values: &[ArrayView<f64, Ix2>],
        constraint_values: &[ArrayView<f64, Ix2>],
    ) -> TrainResult<f64> {
        let supervised = self.supervised_loss.compute(predictions, targets)?;
        let total = self.config.supervised_weight * supervised;
        let total = Self::accumulate_against_constant(
            total,
            rule_values,
            &self.rule_losses,
            1.0,
            self.config.rule_weight,
        )?;
        Self::accumulate_against_constant(
            total,
            constraint_values,
            &self.constraint_losses,
            0.0,
            self.config.constraint_weight,
        )
    }

    /// Adds `weight * loss(value, target)` to `total` for each positionally paired
    /// `(value, loss)`. Each `target` has `value`'s shape and is filled with
    /// `target_value`.
    fn accumulate_against_constant(
        mut total: f64,
        values: &[ArrayView<f64, Ix2>],
        losses: &[Box<dyn Loss>],
        target_value: f64,
        weight: f64,
    ) -> TrainResult<f64> {
        for (value, loss) in values.iter().zip(losses) {
            // Shape the target after *this* value. A single target sized from
            // `values[0]` mismatches any value array of a different shape.
            let target = Array::from_elem(value.raw_dim(), target_value);
            total += weight * loss.compute(value, &target.view())?;
        }
        Ok(total)
    }
}
/// Parameters for a loss that penalizes constraint violations.
///
/// NOTE(review): the `Loss` implementation for this type is not in this part of
/// the file; the field description below is inferred from its name — confirm.
#[derive(Debug, Clone)]
pub struct ConstraintViolationLoss {
/// Presumably scales the violation penalty (TODO confirm against the `Loss` impl).
pub penalty_weight: f64,
}
/// Parameters for a cross-entropy loss.
///
/// NOTE(review): the `Loss` implementation for this type is not in this part of
/// the file; the field description below is inferred from its name — confirm.
#[derive(Debug, Clone)]
pub struct CrossEntropyLoss {
/// Small constant, presumably used to keep probabilities away from 0 before
/// taking logarithms (TODO confirm).
pub epsilon: f64,
}
/// Parameters for a Kullback–Leibler divergence loss.
///
/// NOTE(review): the `Loss` implementation for this type is not in this part of
/// the file; the field description below is inferred from its name — confirm.
#[derive(Debug, Clone)]
pub struct KLDivergenceLoss {
/// Small constant, presumably guarding logarithms and divisions against zero
/// probabilities (TODO confirm).
pub epsilon: f64,
}
/// Parameters for a Huber loss.
///
/// NOTE(review): the `Loss` implementation for this type is not in this part of
/// the file; the field description below is inferred from its name — confirm.
#[derive(Debug, Clone)]
pub struct HuberLoss {
/// Presumably the residual magnitude at which the loss switches from quadratic
/// to linear (TODO confirm against the `Loss` impl).
pub delta: f64,
}