use crate::ops::gradient_descent_ops::sgd;
use crate::tensor::{Input, Tensor};
use crate::Float;
use crate::Graph;
/// Plain stochastic-gradient-descent optimizer configuration.
///
/// Holds only the step size; the update ops themselves are built by
/// `compute_updates`.
pub struct SGD<T>
where
    T: Float,
{
    /// Learning rate, forwarded unchanged to `sgd::SGDOp::new` for every parameter.
    pub lr: T,
}
impl<'b, T: Float> SGD<T> {
    /// Builds one SGD update op per `(param, grad)` pair in graph `c`.
    ///
    /// For each index `i`, `params[i]` is wired in as a mutable input
    /// (`Input::new_mut`) and `grads[i]` as an immutable input (`Input::new`)
    /// of an `sgd::SGDOp` configured with `self.lr`. Presumably the op updates
    /// the parameter in place when evaluated; confirm against `SGDOp`.
    ///
    /// Returns the update tensors in the same order as `params`.
    ///
    /// # Panics
    ///
    /// Panics if `params` and `grads` have different lengths. Every parameter
    /// needs exactly one gradient.
    pub fn compute_updates(
        &self,
        params: Vec<Tensor<'b, T>>,
        grads: Vec<Tensor<'b, T>>,
        c: &'b Graph<T>,
    ) -> Vec<Tensor<'b, T>> {
        // Fail loudly on a mismatch instead of relying on an out-of-bounds index
        // panic, or silently ignoring surplus gradients.
        assert_eq!(
            params.len(),
            grads.len(),
            "SGD::compute_updates: params and grads must have the same length"
        );
        params
            .iter()
            .zip(grads.iter())
            .map(|(param, grad)| {
                Tensor::builder()
                    .set_inputs(&[Input::new_mut(param), Input::new(grad)])
                    .build(c, sgd::SGDOp::new(self.lr))
            })
            .collect()
    }
}