use crate::error::{NeuralError, Result};
use crate::losses::Loss;
use scirs2_core::ndarray::{Array, IxDyn, Zip};
use scirs2_core::numeric::{Float, NumAssign};
use scirs2_symbolic::eml::eval::{eval_real, EvalCtx};
use scirs2_symbolic::eml::grad;
use scirs2_symbolic::eml::op::LoweredOp;
use std::fmt::Debug;
use std::sync::Arc;
/// A [`Loss`] whose per-element value is given by a lowered symbolic expression.
///
/// The expression is evaluated with two inputs per element: input 0 is the prediction and
/// input 1 is the target. Both expressions are held behind `Arc`, so cloning a
/// `SymbolicLoss` only bumps reference counts.
#[derive(Clone)]
pub struct SymbolicLoss {
    /// Per-element loss expression, evaluated at `[prediction, target]`.
    op: Arc<LoweredOp>,
    /// Derivative of `op` with respect to input 0 (the prediction), precomputed at construction.
    grad_pred_op: Arc<LoweredOp>,
}
impl SymbolicLoss {
    /// Builds a symbolic loss from the lowered expression `op`.
    ///
    /// `op` is expected to take two inputs, with the prediction at index 0 and the target at
    /// index 1. Its symbolic derivative with respect to input 0 is computed once here and
    /// cached for use by `backward`.
    ///
    /// # Errors
    ///
    /// The current implementation never returns `Err`.
    pub fn new(op: Arc<LoweredOp>) -> Result<Self> {
        let derivative = Arc::new(grad(op.as_ref(), 0));
        Ok(Self {
            grad_pred_op: derivative,
            op,
        })
    }
}
impl Debug for SymbolicLoss {
    /// Writes only the type name; the stored expressions are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("SymbolicLoss")
    }
}
impl<F: Float + Debug + NumAssign> Loss<F> for SymbolicLoss {
fn forward(&self, predictions: &Array<F, IxDyn>, targets: &Array<F, IxDyn>) -> Result<F> {
if predictions.shape() != targets.shape() {
return Err(NeuralError::ShapeMismatch(format!(
"SymbolicLoss shape mismatch: predictions {:?} vs targets {:?}",
predictions.shape(),
targets.shape()
)));
}
let n = predictions.len();
if n == 0 {
return Err(NeuralError::InvalidArgument(
"empty prediction array in SymbolicLoss::forward".to_string(),
));
}
let mut sum = F::zero();
let mut error: Option<NeuralError> = None;
Zip::from(predictions).and(targets).for_each(|&p, &t| {
if error.is_some() {
return;
}
let pf = p.to_f64().unwrap_or(0.0);
let tf = t.to_f64().unwrap_or(0.0);
match eval_real(self.op.as_ref(), &EvalCtx::new(&[pf, tf])) {
Ok(v) => {
let vf = F::from(v).unwrap_or_else(F::nan);
sum += vf;
}
Err(e) => {
error = Some(NeuralError::ComputationError(e.to_string()));
}
}
});
if let Some(e) = error {
return Err(e);
}
let n_f = F::from(n).ok_or_else(|| {
NeuralError::ComputationError("could not convert array length to float".to_string())
})?;
Ok(sum / n_f)
}
fn backward(
&self,
predictions: &Array<F, IxDyn>,
targets: &Array<F, IxDyn>,
) -> Result<Array<F, IxDyn>> {
if predictions.shape() != targets.shape() {
return Err(NeuralError::ShapeMismatch(format!(
"SymbolicLoss shape mismatch: predictions {:?} vs targets {:?}",
predictions.shape(),
targets.shape()
)));
}
let n = predictions.len();
if n == 0 {
return Err(NeuralError::InvalidArgument(
"empty prediction array in SymbolicLoss::backward".to_string(),
));
}
let n_f = F::from(n).ok_or_else(|| {
NeuralError::ComputationError("could not convert array length to float".to_string())
})?;
let mut grad_out = Array::zeros(predictions.raw_dim());
let mut error: Option<NeuralError> = None;
Zip::from(&mut grad_out)
.and(predictions)
.and(targets)
.for_each(|out, &p, &t| {
if error.is_some() {
return;
}
let pf = p.to_f64().unwrap_or(0.0);
let tf = t.to_f64().unwrap_or(0.0);
match eval_real(self.grad_pred_op.as_ref(), &EvalCtx::new(&[pf, tf])) {
Ok(dldp) => {
let dldp_f = F::from(dldp).unwrap_or_else(F::nan);
*out = dldp_f / n_f;
}
Err(e) => {
error = Some(NeuralError::ComputationError(e.to_string()));
}
}
});
if let Some(e) = error {
return Err(e);
}
Ok(grad_out)
}
}