use crate::{
arith::ArithAlgebra,
array::ArrayAlgebra,
core::{CoreAlgebra, HasDims},
error::{check_equal_dimensions, Error, Result},
graph::Value,
matrix::MatrixAlgebra,
net::{HasGradientId, HasGradientReader, Net, WeightOps},
Graph1, Number,
};
use serde::{Deserialize, Serialize};
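
/// Networks producing a single output value of type `Data` in the algebra
/// `Algebra`. Such a network can be extended with a square loss: the wrapped
/// network takes an additional target value and outputs `norm2(target - output)`.
///
/// A minimal sketch, assuming a hypothetical single-output network type `MyNet`:
/// ```ignore
/// // The wrapped network consumes `(input, target)` pairs and outputs the loss.
/// let net = MyNet::new().add_square_loss();
/// ```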
pub trait SingleOutputNet<Data, Algebra>: Net<Algebra>
where
Algebra: HasGradientReader + CoreAlgebra<Data, Value = Self::Output>,
{
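    /// Wrap this network so that it takes an additional target value of type
    /// `Data` and outputs the square loss between target and original output.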
fn add_square_loss(self) -> SquareLoss<Self, Data>
where
Self: Sized,
{
SquareLoss(self, std::marker::PhantomData)
}
}

impl<Data, Algebra, N> SingleOutputNet<Data, Algebra> for N
where
N: Net<Algebra>,
Algebra: HasGradientReader + CoreAlgebra<Data, Value = Self::Output>,
{
}
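
/// Networks over the default graph `Graph1` whose output is a scalar value,
/// typically a loss. Provides a simple batched gradient step on the weights.
///
/// A minimal training-loop sketch, assuming `net` implements `DiffNet<f32>`
/// (for instance a network wrapped by `add_square_loss`); the batches and the
/// value of `lambda` are illustrative only:
/// ```ignore
/// for batch in batches {
///     // `lambda` scales the gradients that form the weight update.
///     let cumulated_loss = net.apply_gradient_step(lambda, batch)?;
/// }
/// ```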
pub trait DiffNet<T>: Net<Graph1, Output = Value<T>>
where
T: Number,
Self::Weights: WeightOps<T>,
{
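    /// Evaluate every example of the batch on a fresh graph, back-propagate
    /// from the scalar output with an initial gradient of one, scale the
    /// resulting weight gradients by `lambda`, and apply the accumulated
    /// update once at the end. Returns the cumulated output over the batch,
    /// or an error if the batch is empty.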
fn apply_gradient_step(&mut self, lambda: T, batch: Vec<Self::Input>) -> Result<T> {
let mut delta: Option<Self::Weights> = None;
let mut cumulated_output: Option<T> = None;
for example in batch {
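            // Evaluate the example on a fresh graph, keeping gradient information.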
let mut g = Graph1::new();
let (output, info) = self.eval_with_gradient_info(&mut g, example)?;
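            // Accumulate the scalar outputs of the batch.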
match &mut cumulated_output {
opt @ None => *opt = Some(*output.data()),
Some(val) => *val += *output.data(),
}
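            // Back-propagate from the output node with an initial gradient of one.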
let store = g.evaluate_gradients_once(output.gid()?, T::one())?;
let gradients = self.read_weight_gradients(info, &store)?;
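            // Scale the weight gradients by `lambda` and accumulate the update.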
match &mut delta {
opt @ None => *opt = Some(gradients.scale(lambda)),
Some(val) => val.add_assign(gradients.scale(lambda))?,
}
}
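        // Apply the accumulated update once for the whole batch.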
if let Some(delta) = delta {
self.update_weights(delta)?;
}
cumulated_output.ok_or_else(|| Error::empty(func_name!()))
}
}

impl<N, T> DiffNet<T> for N
where
T: Number,
N: Net<Graph1, Output = Value<T>>,
N::Weights: WeightOps<T>,
{
}
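
/// Wrapper returned by `SingleOutputNet::add_square_loss`. The inner network's
/// output is replaced by the square loss `norm2(target - output)`, where the
/// target is passed as the second component of the input.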
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SquareLoss<N, Data>(N, std::marker::PhantomData<Data>);

impl<Data, Algebra, N> Net<Algebra> for SquareLoss<N, Data>
where
Algebra: HasGradientReader
+ CoreAlgebra<Data, Value = N::Output>
+ ArrayAlgebra<N::Output>
+ ArithAlgebra<N::Output>
+ MatrixAlgebra<N::Output>,
N: Net<Algebra>,
Data: HasDims,
N::Output: HasDims<Dims = Data::Dims>,
Data::Dims: Clone + PartialEq + std::fmt::Debug,
{
type Input = (N::Input, Data);
type Output = <Algebra as ArrayAlgebra<N::Output>>::Scalar;
type Weights = N::Weights;
type GradientInfo = N::GradientInfo;
fn eval_with_gradient_info(
&self,
graph: &mut Algebra,
input: Self::Input,
) -> Result<(Self::Output, Self::GradientInfo)> {
let (output, info) = self.0.eval_with_gradient_info(graph, input.0)?;
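        // The target must have the same dimensions as the inner network's output.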
check_equal_dimensions(
"eval_with_gradient_info",
&[&output.dims(), &input.1.dims()],
)?;
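        // Square loss: norm2(target - output).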
let target = graph.constant(input.1);
let delta = graph.sub(&target, &output)?;
let loss = graph.norm2(&delta);
Ok((loss, info))
}
fn get_weights(&self) -> Self::Weights {
self.0.get_weights()
}
fn set_weights(&mut self, weights: Self::Weights) -> Result<()> {
self.0.set_weights(weights)
}
fn update_weights(&mut self, delta: Self::Weights) -> Result<()> {
self.0.update_weights(delta)
}
fn read_weight_gradients(
&self,
info: Self::GradientInfo,
store: &Algebra::GradientReader,
) -> Result<Self::Weights> {
self.0.read_weight_gradients(info, store)
}
}