use crate::ndarray_ext::NdArray;
use crate::op::OpError;
use crate::Float;
/// The `Variable` op.
///
/// Evaluation re-emits input 0 as an owned array, or, when the node has no inputs,
/// a zero-filled array created from an empty shape. Differentiation forwards the
/// output gradient unchanged to input 0.
pub struct Variable;

impl<T: Float> crate::op::Op<T> for Variable {
    /// Appends exactly one output to `ctx` and always returns `Ok(())`.
    fn compute(&self, ctx: &mut crate::op::ComputeContext<T>) -> Result<(), OpError> {
        // Guard clause: nothing to copy, so emit zeros built from an empty shape.
        // NOTE(review): presumably this is a 0-dimensional array — confirm against `NdArray::zeros`.
        if ctx.inputs().is_empty() {
            ctx.append_output(NdArray::zeros(vec![]));
            return Ok(());
        }

        // Materialise the owned copy in a local before handing it back to `ctx`.
        let owned = ctx.input(0).to_owned();
        ctx.append_output(owned);
        Ok(())
    }

    /// Registers the output gradient as the gradient of input 0.
    fn grad(&self, ctx: &mut crate::op::GradientContext<T>) {
        let upstream = *ctx.output_grad();
        ctx.append_input_grad(0, Some(upstream));
    }
}
/// The `Const` op.
///
/// Evaluation re-emits input 0 as an owned array, or, when the node has no inputs,
/// a zero-filled array created from an empty shape. Differentiation records `None`
/// as the gradient of input 0.
pub struct Const;

impl<T: Float> crate::op::Op<T> for Const {
    /// Appends exactly one output to `ctx` and always returns `Ok(())`.
    fn compute(&self, ctx: &mut crate::op::ComputeContext<T>) -> Result<(), OpError> {
        // Guard clause: nothing to copy, so emit zeros built from an empty shape.
        // NOTE(review): presumably this is a 0-dimensional array — confirm against `NdArray::zeros`.
        if ctx.inputs().is_empty() {
            ctx.append_output(NdArray::zeros(vec![]));
            return Ok(());
        }

        // Materialise the owned copy in a local before handing it back to `ctx`.
        let owned = ctx.input(0).to_owned();
        ctx.append_output(owned);
        Ok(())
    }

    /// Records `None` for input 0; the output gradient is not read at all.
    fn grad(&self, ctx: &mut crate::op::GradientContext<T>) {
        ctx.append_input_grad(0, None);
    }
}
/// The `Placeholder` op.
///
/// Evaluation re-emits input 0 as an owned array, or, when the node has no inputs,
/// a zero-filled array created from an empty shape. Differentiation forwards the
/// output gradient unchanged to input 0.
pub struct Placeholder;

impl<T: Float> crate::op::Op<T> for Placeholder {
    /// Appends exactly one output to `ctx` and always returns `Ok(())`.
    fn compute(&self, ctx: &mut crate::op::ComputeContext<T>) -> Result<(), OpError> {
        // Guard clause: nothing to copy, so emit zeros built from an empty shape.
        // NOTE(review): presumably this is a 0-dimensional array — confirm against `NdArray::zeros`.
        if ctx.inputs().is_empty() {
            ctx.append_output(NdArray::zeros(vec![]));
            return Ok(());
        }

        // Materialise the owned copy in a local before handing it back to `ctx`.
        let owned = ctx.input(0).to_owned();
        ctx.append_output(owned);
        Ok(())
    }

    /// Registers the output gradient as the gradient of input 0.
    fn grad(&self, ctx: &mut crate::op::GradientContext<T>) {
        let upstream = *ctx.output_grad();
        ctx.append_input_grad(0, Some(upstream));
    }
}