scirs2-autograd 0.3.2

Automatic differentiation module for SciRS2 (scirs2-autograd)
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
use crate::Float;
use crate::{op, NdArray, NdArrayView};
use std::marker::PhantomData;

/// Operation that carries a plain function pointer to be applied to an
/// input array when the op is evaluated.
pub(crate) struct MapOp<T>
where
    T: Float,
{
    /// Zero-sized marker that lets the struct be generic over `T` without
    /// storing a value of that type.
    pub(crate) phantom: PhantomData<T>,
    /// Function pointer that receives an `NdArrayView<T>` and returns an
    /// `NdArray<T>`.
    pub(crate) f: fn(NdArrayView<T>) -> NdArray<T>,
}

impl<F: Float> op::Op<F> for MapOp<F> {
    /// Applies the stored function pointer to input 0 and appends the
    /// result as an output of this op. Always returns `Ok(())`.
    fn compute(&self, ctx: &mut op::ComputeContext<F>) -> Result<(), op::OpError> {
        // Fetch the input first so the call on `ctx` is finished before
        // `ctx` is used again to append the output.
        let input = ctx.input(0);
        let mapped = (self.f)(input);
        ctx.append_output(mapped);
        Ok(())
    }

    /// Does nothing: the gradient context is left untouched.
    ///
    /// NOTE(review): no input gradient is registered here; confirm that
    /// the surrounding graph machinery accepts an op whose `grad` appends
    /// nothing.
    fn grad<'a>(&self, _ctx: &mut op::GradientContext<'a, 'a, F>) {
        // No gradient is propagated through the mapped function.
    }
}