[−][src]Trait autograd::op::Op
Operation trait. `Tensor` wraps a trait object of this type.
Implementing differentiable operations
Many well-known ops are pre-defined in `Graph`, but you can also
implement custom ops by hand.
use ndarray;
use autograd as ag;
use ag::Graph;
use ag::tensor::Input;

// Dynamically-dimensioned ndarray alias used throughout the examples.
type NdArray<T: ag::Float> = ndarray::Array<T, ndarray::IxDyn>;

/// Marker type implementing the `Op` trait for the sigmoid activation.
struct Sigmoid;

impl<T: ag::Float> ag::op::Op<T> for Sigmoid {
    /// Forward pass: applies the logistic sigmoid element-wise to input 0.
    fn compute(&self, ctx: &mut ag::op::ComputeContext<T>) {
        let input: &ag::NdArrayView<_> = &ctx.input(0);
        let h = T::from(0.5).unwrap();
        // sigmoid(a) expressed via tanh: 0.5 * tanh(0.5 * a) + 0.5.
        // `ndarray::Array::mapv` performs the element-wise computation.
        let out = input.mapv(move |a| ((a * h).tanh() * h) + h);
        ctx.append_output(out);
    }

    /// Backward pass: symbolic gradient of sigmoid, gx = gy * y * (1 - y).
    fn grad(&self, ctx: &mut ag::op::GradientContext<T>) {
        let grad_out = ctx.output_grad();
        let out = ctx.output();
        // y - y^2 == y * (1 - y), built symbolically on the graph.
        let grad_in = grad_out * (out - ctx.graph().square(out));
        ctx.append_input_grad(Some(grad_in));
    }
}

/// End-user symbolic `sigmoid` function: wires a `Sigmoid` op node into `g`.
fn sigmoid<'graph, F: ag::Float>(
    x: &ag::Tensor<'graph, F>,
    g: &'graph ag::Graph<F>,
) -> ag::Tensor<'graph, F> {
    ag::Tensor::builder()
        .set_inputs(&[Input::new(x)])
        .build(g, Sigmoid)
}
Required methods
fn compute(&self, ctx: &mut ComputeContext<F>)
Runs this op with its `ComputeContext`.
fn grad(&self, ctx: &mut GradientContext<F>)
Returns symbolic gradients for the input nodes, computed from the output's gradients and other context.