use crate::ndarray::compat::ArrayStatCompat;
use ::ndarray::{Array, ArrayD, Dimension, IxDyn};
use std::cell::RefCell;
use std::collections::{HashMap, HashSet};
use std::rc::Rc;
use crate::array_protocol::operations::matmul;
use crate::array_protocol::{ArrayProtocol, NdarrayWrapper};
use crate::error::{CoreError, CoreResult, ErrorContext};
/// A name-keyed collection of gradient arrays.
///
/// Each entry maps a parameter name to its gradient, stored as a boxed
/// [`ArrayProtocol`] value. Inserting under an existing name replaces the
/// previous gradient.
#[derive(Clone)]
pub struct GradientDict {
    gradients: HashMap<String, Box<dyn ArrayProtocol>>,
}

impl std::fmt::Debug for GradientDict {
    /// Prints only the parameter names; the gradient arrays are omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let names: Vec<&String> = self.gradients.keys().collect();
        let summary = format!("{{keys: {:?}}}", names);
        f.debug_struct("GradientDict")
            .field("gradients", &summary)
            .finish()
    }
}

impl GradientDict {
    /// Creates an empty dictionary.
    pub fn new() -> Self {
        let gradients = HashMap::new();
        Self { gradients }
    }

    /// Stores `gradient` under `name`, replacing any previous entry.
    pub fn insert(&mut self, name: String, gradient: Box<dyn ArrayProtocol>) {
        self.gradients.insert(name, gradient);
    }

    /// Borrows the gradient stored under `name`, if present.
    pub fn get(&self, name: &str) -> Option<&dyn ArrayProtocol> {
        self.gradients.get(name).map(|boxed| &**boxed)
    }

    /// Mutably borrows the boxed gradient stored under `name`, if present.
    pub fn get_mut(&mut self, name: &str) -> Option<&mut Box<dyn ArrayProtocol>> {
        self.gradients.get_mut(name)
    }

    /// Iterates over `(name, gradient)` pairs in unspecified order.
    pub fn iter(&self) -> impl Iterator<Item = (&String, &Box<dyn ArrayProtocol>)> {
        self.gradients.iter()
    }

    /// Moves every entry of `other` into `self`; on a name clash the entry
    /// from `other` wins.
    pub fn merge(&mut self, other: GradientDict) {
        self.gradients.extend(other.gradients);
    }

    /// Returns `true` when no gradients are stored.
    pub fn is_empty(&self) -> bool {
        self.gradients.is_empty()
    }

    /// Number of stored gradients.
    pub fn len(&self) -> usize {
        self.gradients.len()
    }

    /// Removes all stored gradients.
    pub fn clear(&mut self) {
        self.gradients.clear();
    }

    /// Iterates over the parameter names in unspecified order.
    pub fn keys(&self) -> impl Iterator<Item = &String> {
        self.gradients.keys()
    }

    /// Iterates over the stored gradients in unspecified order.
    pub fn values(&self) -> impl Iterator<Item = &Box<dyn ArrayProtocol>> {
        self.gradients.values()
    }
}

impl Default for GradientDict {
    /// Equivalent to [`GradientDict::new`].
    fn default() -> Self {
        Self::new()
    }
}
#[allow(dead_code)]
fn boxed_to_rc(boxed: Box<dyn ArrayProtocol>) -> Rc<dyn ArrayProtocol> {
let array_ref = boxed.as_ref();
if let Some(ndarray_wrapper) = array_ref
.as_any()
.downcast_ref::<NdarrayWrapper<f64, IxDyn>>()
{
let array_clone = ndarray_wrapper.as_array().clone();
return Rc::new(NdarrayWrapper::new(array_clone));
}
let fallback_array = ArrayD::<f64>::zeros(IxDyn(&[1, 1]));
Rc::new(NdarrayWrapper::new(fallback_array))
}
#[allow(dead_code)]
fn box_to_rc_array_protocol(boxed: Box<dyn ArrayProtocol>) -> Rc<dyn ArrayProtocol> {
boxed_to_rc(boxed)
}
/// Element-wise `a + b` via `crate::array_protocol::operations::add`, with the
/// operation error converted into a [`CoreError`].
#[allow(dead_code)]
fn add(a: &dyn ArrayProtocol, b: &dyn ArrayProtocol) -> CoreResult<Box<dyn ArrayProtocol>> {
    let outcome = crate::array_protocol::operations::add(a, b);
    outcome.map_err(Into::into)
}

/// Element-wise `a * b` via `crate::array_protocol::operations::multiply`, with
/// the operation error converted into a [`CoreError`].
#[allow(dead_code)]
fn multiply(a: &dyn ArrayProtocol, b: &dyn ArrayProtocol) -> CoreResult<Box<dyn ArrayProtocol>> {
    let outcome = crate::array_protocol::operations::multiply(a, b);
    outcome.map_err(Into::into)
}

/// Element-wise `a - b` via `crate::array_protocol::operations::subtract`, with
/// the operation error converted into a [`CoreError`].
#[allow(dead_code)]
fn subtract(a: &dyn ArrayProtocol, b: &dyn ArrayProtocol) -> CoreResult<Box<dyn ArrayProtocol>> {
    let outcome = crate::array_protocol::operations::subtract(a, b);
    outcome.map_err(Into::into)
}
#[allow(dead_code)]
fn ones_like(a: &dyn ArrayProtocol) -> CoreResult<Box<dyn ArrayProtocol>> {
if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>() {
let shape = a_array.as_array().shape();
let ones = ArrayD::<f64>::ones(IxDyn(shape));
Ok(Box::new(NdarrayWrapper::new(ones)) as Box<dyn ArrayProtocol>)
} else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f32, IxDyn>>() {
let shape = a_array.as_array().shape();
let ones = ArrayD::<f32>::ones(IxDyn(shape));
Ok(Box::new(NdarrayWrapper::new(ones)) as Box<dyn ArrayProtocol>)
} else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<i32, IxDyn>>() {
let shape = a_array.as_array().shape();
let ones = ArrayD::<i32>::ones(IxDyn(shape));
Ok(Box::new(NdarrayWrapper::new(ones)) as Box<dyn ArrayProtocol>)
} else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<i64, IxDyn>>() {
let shape = a_array.as_array().shape();
let ones = ArrayD::<i64>::ones(IxDyn(shape));
Ok(Box::new(NdarrayWrapper::new(ones)) as Box<dyn ArrayProtocol>)
} else {
let shape = a.shape().to_vec();
let ones = ArrayD::<f64>::ones(IxDyn(&shape));
Ok(Box::new(NdarrayWrapper::new(ones)) as Box<dyn ArrayProtocol>)
}
}
/// Broadcasts `a` to `shape`, returning a new owned array.
///
/// Handled inputs are `NdarrayWrapper<f64, IxDyn>` and
/// `NdarrayWrapper<f32, IxDyn>`:
/// * a single-element array is replicated into `shape`;
/// * an array already of shape `shape` is cloned;
/// * otherwise trailing-axis broadcasting (via `ndarray`'s `broadcast`) is
///   attempted.
///
/// Any other array type yields an `f64` array of ones with the target shape.
/// NOTE(review): that fallback ignores the input's values — confirm it is
/// intended for callers such as the `mean` backward pass.
///
/// # Errors
/// `CoreError::NotImplementedError` when the shapes cannot be broadcast. For
/// `f64` inputs this reports "Incompatible shapes for broadcasting" when the
/// input has more axes than `shape` or a non-unit axis mismatches.
#[allow(dead_code)]
fn broadcast_to(a: &dyn ArrayProtocol, shape: &[usize]) -> CoreResult<Box<dyn ArrayProtocol>> {
    if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>() {
        let array = a_array.as_array();
        if array.len() == 1 {
            let value = array.iter().next().cloned().unwrap_or(0.0);
            let broadcasted = ArrayD::<f64>::from_elem(IxDyn(shape), value);
            Ok(Box::new(NdarrayWrapper::new(broadcasted)) as Box<dyn ArrayProtocol>)
        } else if array.shape() == shape {
            Ok(Box::new(NdarrayWrapper::new(array.clone())) as Box<dyn ArrayProtocol>)
        } else {
            let inputshape = array.shape();
            // Axes are aligned from the right. The rank check must come first.
            // An input with more axes than the target cannot broadcast, and
            // without the check the previous index arithmetic underflowed.
            let can_broadcast = inputshape.len() <= shape.len()
                && inputshape
                    .iter()
                    .rev()
                    .zip(shape.iter().rev())
                    .all(|(&input_dim, &target_dim)| input_dim == 1 || input_dim == target_dim);
            if can_broadcast {
                if let Some(broadcasted_view) = array.broadcast(IxDyn(shape)) {
                    let broadcasted = broadcasted_view.to_owned();
                    Ok(Box::new(NdarrayWrapper::new(broadcasted)) as Box<dyn ArrayProtocol>)
                } else {
                    Err(CoreError::NotImplementedError(ErrorContext::new(
                        "Broadcasting failed for these shapes".to_string(),
                    )))
                }
            } else {
                Err(CoreError::NotImplementedError(ErrorContext::new(
                    "Incompatible shapes for broadcasting".to_string(),
                )))
            }
        }
    } else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f32, IxDyn>>() {
        let array = a_array.as_array();
        if array.len() == 1 {
            let value = array.iter().next().cloned().unwrap_or(0.0);
            let broadcasted = ArrayD::<f32>::from_elem(IxDyn(shape), value);
            Ok(Box::new(NdarrayWrapper::new(broadcasted)) as Box<dyn ArrayProtocol>)
        } else if array.shape() == shape {
            Ok(Box::new(NdarrayWrapper::new(array.clone())) as Box<dyn ArrayProtocol>)
        } else if let Some(broadcasted_view) = array.broadcast(IxDyn(shape)) {
            let broadcasted = broadcasted_view.to_owned();
            Ok(Box::new(NdarrayWrapper::new(broadcasted)) as Box<dyn ArrayProtocol>)
        } else {
            Err(CoreError::NotImplementedError(ErrorContext::new(
                "Broadcasting failed for these shapes".to_string(),
            )))
        }
    } else {
        let ones = ArrayD::<f64>::ones(IxDyn(shape));
        Ok(Box::new(NdarrayWrapper::new(ones)) as Box<dyn ArrayProtocol>)
    }
}
/// One vertex of the autograd graph.
///
/// Leaves come from [`Node::leaf`]. Interior vertices come from
/// [`Node::new_op`] and record the producing operation's name and operands.
#[derive(Clone)]
struct Node {
    /// Forward value held at this vertex.
    value: Rc<dyn ArrayProtocol>,
    /// Gradient accumulated during a backward pass, if any.
    grad: Option<Rc<dyn ArrayProtocol>>,
    /// Name of the producing operation; `None` for leaves.
    op: Option<String>,
    /// Operands of `op` (empty for leaves).
    inputs: Vec<GradientTensor>,
    /// Whether gradients are tracked for this vertex.
    requiresgrad: bool,
    /// `true` for vertices not produced by an operation.
    is_leaf: bool,
}

impl Node {
    /// Builds a leaf whose value is a placeholder 0-d `f64` zero; callers are
    /// expected to overwrite `value`.
    fn leaf(requiresgrad: bool) -> Self {
        let placeholder =
            Rc::new(NdarrayWrapper::new(crate::ndarray::Array0::<f64>::zeros(())))
                as Rc<dyn ArrayProtocol>;
        Self {
            is_leaf: true,
            requiresgrad,
            inputs: Vec::new(),
            op: None,
            grad: None,
            value: placeholder,
        }
    }

    /// Builds an interior vertex for `op`. It tracks gradients exactly when
    /// at least one of its inputs does.
    fn new_op(value: Rc<dyn ArrayProtocol>, op: String, inputs: Vec<GradientTensor>) -> Self {
        let needs_grad = inputs.iter().any(|operand| operand.requiresgrad());
        Self {
            is_leaf: false,
            requiresgrad: needs_grad,
            inputs,
            op: Some(op),
            grad: None,
            value,
        }
    }
}
#[derive(Clone)]
pub struct GradientTensor {
node: Rc<RefCell<Node>>,
}
impl GradientTensor {
pub fn new(value: Rc<dyn ArrayProtocol>, requiresgrad: bool) -> Self {
let mut node_inner = Node::leaf(requiresgrad);
node_inner.value = value;
node_inner.grad = None;
let node = Rc::new(RefCell::new(node_inner));
Self { node }
}
pub fn from_array<T, D>(array: Array<T, D>, requiresgrad: bool) -> Self
where
T: Clone + Send + Sync + 'static,
D: Dimension + Send + Sync + 'static,
{
let value = Rc::new(NdarrayWrapper::new(array)) as Rc<dyn ArrayProtocol>;
Self::new(value, requiresgrad)
}
pub fn value(&self) -> Rc<dyn ArrayProtocol> {
self.node.borrow().value.clone()
}
pub fn grad_2(&self) -> Option<Rc<dyn ArrayProtocol>> {
self.node.borrow().grad.clone()
}
pub fn requiresgrad(&self) -> bool {
self.node.borrow().requiresgrad
}
pub fn set_requiresgrad(&mut self, requiresgrad: bool) {
self.node.borrow_mut().requiresgrad = requiresgrad;
}
pub fn is_leaf(&self) -> bool {
self.node.borrow().is_leaf
}
fn from_op(value: Rc<dyn ArrayProtocol>, op: String, inputs: Vec<GradientTensor>) -> Self {
let node = Rc::new(RefCell::new(Node::new_op(value, op, inputs)));
Self { node }
}
pub fn set_value(&mut self, newvalue: Rc<dyn ArrayProtocol>) {
self.node.borrow_mut().grad = None; self.node.borrow_mut().value = newvalue;
}
pub fn backward(&self) -> CoreResult<()> {
let gradshape = if let Some(array) = self
.value()
.as_any()
.downcast_ref::<NdarrayWrapper<f64, IxDyn>>()
{
array.as_array().raw_dim()
} else {
crate::ndarray::IxDyn(&[1])
};
let grad_array = Array::<f64, IxDyn>::ones(gradshape);
let grad = Rc::new(NdarrayWrapper::new(grad_array)) as Rc<dyn ArrayProtocol>;
self.backward_with_grad(grad)
}
fn backward_with_grad(&self, grad: Rc<dyn ArrayProtocol>) -> CoreResult<()> {
self.node.borrow_mut().grad = Some(grad.clone());
let mut visited = HashSet::new();
let mut topo = Vec::new();
fn build_topo(
tensor: &GradientTensor,
visited: &mut HashSet<*const RefCell<Node>>,
topo: &mut Vec<GradientTensor>,
) {
let node_ptr = Rc::as_ptr(&tensor.node);
if !visited.contains(&node_ptr) {
visited.insert(node_ptr);
for input in &tensor.node.borrow().inputs {
build_topo(input, visited, topo);
}
topo.push(tensor.clone());
}
}
build_topo(self, &mut visited, &mut topo);
for node in topo.iter().rev() {
if !node.requiresgrad() {
continue;
}
let node_grad = match node.grad_2() {
Some(g) => g,
None => continue, };
if node.is_leaf() {
continue;
}
let op = match &node.node.borrow().op {
Some(op) => op.clone(),
None => continue, };
let inputs = node.node.borrow().inputs.clone();
match op.as_str() {
"add" => {
for input in &inputs {
if input.requiresgrad() {
let mut input_node = input.node.borrow_mut();
if let Some(input_grad) = &input_node.grad {
if let Ok(sum) = add(input_grad.as_ref(), node_grad.as_ref()) {
input_node.grad = Some(sum.into());
}
} else {
input_node.grad = Some(node_grad.clone());
}
}
}
}
"multiply"
if inputs.len() == 2 => {
let (a, b) = (&inputs[0], &inputs[1]);
if a.requiresgrad() {
let b_value = b.value();
if let Ok(grad_a) = multiply(node_grad.as_ref(), b_value.as_ref()) {
let mut a_node = a.node.borrow_mut();
if let Some(a_grad) = &a_node.grad {
if let Ok(sum) = add(a_grad.as_ref(), grad_a.as_ref()) {
a_node.grad = Some(box_to_rc_array_protocol(sum));
}
} else {
a_node.grad = Some(box_to_rc_array_protocol(grad_a));
}
}
}
if b.requiresgrad() {
let a_value = a.value();
if let Ok(grad_b) = multiply(node_grad.as_ref(), a_value.as_ref()) {
let mut b_node = b.node.borrow_mut();
if let Some(b_grad) = &b_node.grad {
if let Ok(sum) = add(b_grad.as_ref(), grad_b.as_ref()) {
b_node.grad = Some(box_to_rc_array_protocol(sum));
}
} else {
b_node.grad = Some(box_to_rc_array_protocol(grad_b));
}
}
}
}
"matmul"
if inputs.len() == 2 => {
let (a, b) = (&inputs[0], &inputs[1]);
if a.requiresgrad() {
if let (Some(b_array), Some(grad_out_array)) = (
b.value()
.as_any()
.downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
node_grad
.as_any()
.downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
) {
let b_array_val = b_array.as_array();
let grad_out_array_val = grad_out_array.as_array();
let b_t = b_array_val.t();
let grad_outshape = grad_out_array_val.shape();
let grad_out_rows = grad_outshape[0];
let grad_out_cols = if grad_outshape.len() > 1 {
grad_outshape.iter().skip(1).product()
} else {
1
};
let grad_out_2d = grad_out_array_val
.clone()
.into_shape_with_order((grad_out_rows, grad_out_cols))
.expect("Operation failed");
let b_tshape = b_t.shape();
let b_t_rows = b_tshape[0];
let b_t_cols = if b_tshape.len() > 1 {
b_tshape.iter().skip(1).product()
} else {
1
};
let b_t_2d = b_t
.clone()
.into_shape_with_order((b_t_rows, b_t_cols))
.expect("Operation failed");
let grad_a_val = grad_out_2d.dot(&b_t_2d);
let grad_a_dyn = grad_a_val.into_dyn();
let grad_a = NdarrayWrapper::new(grad_a_dyn);
let mut a_node = a.node.borrow_mut();
if let Some(a_grad) = &a_node.grad {
if let (Some(a_grad_array), Some(grad_a_array)) = (
a_grad
.as_any()
.downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
grad_a
.as_any()
.downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
) {
let sum = a_grad_array.as_array() + grad_a_array.as_array();
a_node.grad = Some(Rc::new(NdarrayWrapper::new(sum)));
}
} else {
a_node.grad = Some(Rc::new(grad_a));
}
}
}
if b.requiresgrad() {
if let (Some(a_array), Some(grad_out_array)) = (
a.value()
.as_any()
.downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
node_grad
.as_any()
.downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
) {
let a_array_val = a_array.as_array();
let grad_out_array_val = grad_out_array.as_array();
let a_t = a_array_val.t();
let grad_outshape = grad_out_array_val.shape();
let grad_out_rows = grad_outshape[0];
let grad_out_cols = if grad_outshape.len() > 1 {
grad_outshape.iter().skip(1).product()
} else {
1
};
let grad_out_2d = grad_out_array_val
.clone()
.into_shape_with_order((grad_out_rows, grad_out_cols))
.expect("Operation failed");
let a_tshape = a_t.shape();
let a_t_rows = a_tshape[0];
let a_t_cols = if a_tshape.len() > 1 {
a_tshape.iter().skip(1).product()
} else {
1
};
let a_t_2d = a_t
.clone()
.into_shape_with_order((a_t_rows, a_t_cols))
.expect("Operation failed");
let grad_b_val = a_t_2d.dot(&grad_out_2d);
let grad_b_dyn = grad_b_val.into_dyn();
let grad_b = NdarrayWrapper::new(grad_b_dyn);
let mut b_node = b.node.borrow_mut();
if let Some(b_grad) = &b_node.grad {
if let (Some(b_grad_array), Some(grad_b_array)) = (
b_grad
.as_any()
.downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
grad_b
.as_any()
.downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
) {
let sum = b_grad_array.as_array() + grad_b_array.as_array();
b_node.grad = Some(Rc::new(NdarrayWrapper::new(sum)));
}
} else {
b_node.grad = Some(Rc::new(grad_b));
}
}
}
}
"subtract"
if inputs.len() == 2 => {
let (a, b) = (&inputs[0], &inputs[1]);
if a.requiresgrad() {
let mut a_node = a.node.borrow_mut();
if let Some(a_grad) = &a_node.grad {
if let Ok(sum) = add(a_grad.as_ref(), node_grad.as_ref()) {
a_node.grad = Some(box_to_rc_array_protocol(sum));
}
} else {
a_node.grad = Some(node_grad.clone());
}
}
if b.requiresgrad() {
if let Ok(neg_grad) = multiply_by_scalar(node_grad.as_ref(), -1.0) {
let mut b_node = b.node.borrow_mut();
if let Some(b_grad) = &b_node.grad {
if let Ok(sum) = add(b_grad.as_ref(), neg_grad.as_ref()) {
b_node.grad = Some(box_to_rc_array_protocol(sum));
}
} else {
b_node.grad = Some(box_to_rc_array_protocol(neg_grad));
}
}
}
}
"divide"
if inputs.len() == 2 => {
let (a, b) = (&inputs[0], &inputs[1]);
if a.requiresgrad() {
let b_value = b.value();
if let Ok(grad_a) = divide(node_grad.as_ref(), b_value.as_ref()) {
let mut a_node = a.node.borrow_mut();
if let Some(a_grad) = &a_node.grad {
if let Ok(sum) = add(a_grad.as_ref(), grad_a.as_ref()) {
a_node.grad = Some(box_to_rc_array_protocol(sum));
}
} else {
a_node.grad = Some(box_to_rc_array_protocol(grad_a));
}
}
}
if b.requiresgrad() {
let a_value = a.value();
let b_value = b.value();
if let Ok(b_squared) = multiply(b_value.as_ref(), b_value.as_ref()) {
if let Ok(grad_times_a) =
multiply(node_grad.as_ref(), a_value.as_ref())
{
if let Ok(div_result) =
divide(grad_times_a.as_ref(), b_squared.as_ref())
{
if let Ok(grad_b) =
multiply_by_scalar(div_result.as_ref(), -1.0)
{
let mut b_node = b.node.borrow_mut();
if let Some(b_grad) = &b_node.grad {
if let Ok(sum) =
add(b_grad.as_ref(), grad_b.as_ref())
{
b_node.grad =
Some(box_to_rc_array_protocol(sum));
}
} else {
b_node.grad =
Some(box_to_rc_array_protocol(grad_b));
}
}
}
}
}
}
}
"sigmoid"
if inputs.len() == 1 => {
let input = &inputs[0];
if input.requiresgrad() {
let sigmoid_value = node.value();
if let Ok(ones) = ones_like(sigmoid_value.as_ref()) {
if let Ok(one_minus_sigmoid) =
subtract(ones.as_ref(), sigmoid_value.as_ref())
{
if let Ok(sigmoid_deriv) =
multiply(sigmoid_value.as_ref(), one_minus_sigmoid.as_ref())
{
if let Ok(grad_input) =
multiply(node_grad.as_ref(), sigmoid_deriv.as_ref())
{
let mut input_node = input.node.borrow_mut();
if let Some(input_grad) = &input_node.grad {
if let Ok(sum) =
add(input_grad.as_ref(), grad_input.as_ref())
{
input_node.grad =
Some(box_to_rc_array_protocol(sum));
}
} else {
input_node.grad =
Some(box_to_rc_array_protocol(grad_input));
}
}
}
}
}
}
}
"mean"
if inputs.len() == 1 => {
let input = &inputs[0];
if input.requiresgrad() {
let input_value = input.value();
if let Some(inputarray) = input_value
.as_any()
.downcast_ref::<NdarrayWrapper<f64, IxDyn>>()
{
let n_elements = inputarray.as_array().len() as f64;
if let Ok(grad_input) =
multiply_by_scalar(node_grad.as_ref(), 1.0 / n_elements)
{
if let Ok(broadcasted_grad) = broadcast_to(
grad_input.as_ref(),
inputarray.as_array().shape(),
) {
let mut input_node = input.node.borrow_mut();
if let Some(input_grad) = &input_node.grad {
if let Ok(sum) =
add(input_grad.as_ref(), broadcasted_grad.as_ref())
{
input_node.grad =
Some(box_to_rc_array_protocol(sum));
}
} else {
input_node.grad =
Some(box_to_rc_array_protocol(broadcasted_grad));
}
}
}
}
}
}
_ => {
}
}
}
Ok(())
}
pub fn detach(&self) -> Self {
GradientTensor::new(self.value(), false)
}
}
#[allow(dead_code)]
pub fn grad_add(a: &GradientTensor, b: &GradientTensor) -> CoreResult<GradientTensor> {
let a_value = a.value();
let b_value = b.value();
let result = add(a_value.as_ref(), b_value.as_ref())?;
let result_rc: Rc<dyn ArrayProtocol> = box_to_rc_array_protocol(result);
Ok(GradientTensor::from_op(
result_rc,
"add".to_string(),
vec![a.clone(), b.clone()],
))
}
#[allow(dead_code)]
pub fn grad_multiply(a: &GradientTensor, b: &GradientTensor) -> CoreResult<GradientTensor> {
let a_value = a.value();
let b_value = b.value();
let result = multiply(a_value.as_ref(), b_value.as_ref())?;
let result_rc: Rc<dyn ArrayProtocol> = box_to_rc_array_protocol(result);
Ok(GradientTensor::from_op(
result_rc,
"multiply".to_string(),
vec![a.clone(), b.clone()],
))
}
#[allow(dead_code)]
pub fn grad_matmul(a: &GradientTensor, b: &GradientTensor) -> CoreResult<GradientTensor> {
let a_value = a.value();
let b_value = b.value();
let result = matmul(a_value.as_ref(), b_value.as_ref())?;
let result_rc: Rc<dyn ArrayProtocol> = box_to_rc_array_protocol(result);
Ok(GradientTensor::from_op(
result_rc,
"matmul".to_string(),
vec![a.clone(), b.clone()],
))
}
#[allow(dead_code)]
pub fn grad_subtract(a: &GradientTensor, b: &GradientTensor) -> CoreResult<GradientTensor> {
let a_value = a.value();
let b_value = b.value();
let result = subtract(a_value.as_ref(), b_value.as_ref())?;
let result_rc: Rc<dyn ArrayProtocol> = box_to_rc_array_protocol(result);
Ok(GradientTensor::from_op(
result_rc,
"subtract".to_string(),
vec![a.clone(), b.clone()],
))
}
#[allow(dead_code)]
pub fn grad_divide(a: &GradientTensor, b: &GradientTensor) -> CoreResult<GradientTensor> {
let a_value = a.value();
let b_value = b.value();
let result = divide(a_value.as_ref(), b_value.as_ref())?;
let result_rc: Rc<dyn ArrayProtocol> = box_to_rc_array_protocol(result);
Ok(GradientTensor::from_op(
result_rc,
"divide".to_string(),
vec![a.clone(), b.clone()],
))
}
/// Records the logistic sigmoid `1 / (1 + e^(-x))` of `a` in the graph.
///
/// Supports `f64` and `f32` `IxDyn` arrays; the result keeps the input's
/// element type.
///
/// # Errors
/// `CoreError::NotImplementedError` for any other array type.
#[allow(dead_code)]
pub fn grad_sigmoid(a: &GradientTensor) -> CoreResult<GradientTensor> {
    let a_value = a.value();
    let any = a_value.as_any();
    let activated: Rc<dyn ArrayProtocol> =
        if let Some(wrapper) = any.downcast_ref::<NdarrayWrapper<f64, IxDyn>>() {
            let mapped = wrapper.as_array().mapv(|x| 1.0 / (1.0 + (-x).exp()));
            Rc::new(NdarrayWrapper::new(mapped))
        } else if let Some(wrapper) = any.downcast_ref::<NdarrayWrapper<f32, IxDyn>>() {
            let mapped = wrapper.as_array().mapv(|x| 1.0f32 / (1.0f32 + (-x).exp()));
            Rc::new(NdarrayWrapper::new(mapped))
        } else {
            return Err(CoreError::NotImplementedError(ErrorContext::new(
                "sigmoid not implemented for this array type".to_string(),
            )));
        };
    Ok(GradientTensor::from_op(
        activated,
        "sigmoid".to_string(),
        vec![a.clone()],
    ))
}
#[allow(dead_code)]
pub fn grad_mean(a: &GradientTensor) -> CoreResult<GradientTensor> {
let a_value = a.value();
if let Some(a_array) = a_value
.as_any()
.downcast_ref::<NdarrayWrapper<f64, IxDyn>>()
{
let array = a_array.as_array();
let mean_value = array.mean_or(0.0);
let result = ArrayD::<f64>::from_elem(IxDyn(&[1]), mean_value);
let result_wrapped = NdarrayWrapper::new(result);
let result_rc: Rc<dyn ArrayProtocol> = Rc::new(result_wrapped);
Ok(GradientTensor::from_op(
result_rc,
"mean".to_string(),
vec![a.clone()],
))
} else if let Some(a_array) = a_value
.as_any()
.downcast_ref::<NdarrayWrapper<f32, IxDyn>>()
{
let array = a_array.as_array();
let mean_value = array.mean_or(0.0f32);
let result = ArrayD::<f32>::from_elem(IxDyn(&[1]), mean_value);
let result_wrapped = NdarrayWrapper::new(result);
let result_rc: Rc<dyn ArrayProtocol> = Rc::new(result_wrapped);
Ok(GradientTensor::from_op(
result_rc,
"mean".to_string(),
vec![a.clone()],
))
} else if let Some(a_array) = a_value
.as_any()
.downcast_ref::<NdarrayWrapper<i32, IxDyn>>()
{
let array = a_array.as_array();
let mean_value = if array.is_empty() {
0.0f64
} else {
array.iter().map(|&x| x as f64).sum::<f64>() / array.len() as f64
};
let result = ArrayD::<f64>::from_elem(IxDyn(&[1]), mean_value);
let result_wrapped = NdarrayWrapper::new(result);
let result_rc: Rc<dyn ArrayProtocol> = Rc::new(result_wrapped);
Ok(GradientTensor::from_op(
result_rc,
"mean".to_string(),
vec![a.clone()],
))
} else if let Some(a_array) = a_value
.as_any()
.downcast_ref::<NdarrayWrapper<i64, IxDyn>>()
{
let array = a_array.as_array();
let mean_value = if array.is_empty() {
0.0f64
} else {
array.iter().map(|&x| x as f64).sum::<f64>() / array.len() as f64
};
let result = ArrayD::<f64>::from_elem(IxDyn(&[1]), mean_value);
let result_wrapped = NdarrayWrapper::new(result);
let result_rc: Rc<dyn ArrayProtocol> = Rc::new(result_wrapped);
Ok(GradientTensor::from_op(
result_rc,
"mean".to_string(),
vec![a.clone()],
))
} else if let Some(a_array) = a_value.as_any().downcast_ref::<NdarrayWrapper<u8, IxDyn>>() {
let array = a_array.as_array();
let mean_value = if array.is_empty() {
0.0f64
} else {
array.iter().map(|&x| x as f64).sum::<f64>() / array.len() as f64
};
let result = ArrayD::<f64>::from_elem(IxDyn(&[1]), mean_value);
let result_wrapped = NdarrayWrapper::new(result);
let result_rc: Rc<dyn ArrayProtocol> = Rc::new(result_wrapped);
Ok(GradientTensor::from_op(
result_rc,
"mean".to_string(),
vec![a.clone()],
))
} else if let Some(a_array) = a_value
.as_any()
.downcast_ref::<NdarrayWrapper<u16, IxDyn>>()
{
let array = a_array.as_array();
let mean_value = if array.is_empty() {
0.0f64
} else {
array.iter().map(|&x| x as f64).sum::<f64>() / array.len() as f64
};
let result = ArrayD::<f64>::from_elem(IxDyn(&[1]), mean_value);
let result_wrapped = NdarrayWrapper::new(result);
let result_rc: Rc<dyn ArrayProtocol> = Rc::new(result_wrapped);
Ok(GradientTensor::from_op(
result_rc,
"mean".to_string(),
vec![a.clone()],
))
} else if let Some(a_array) = a_value
.as_any()
.downcast_ref::<NdarrayWrapper<u32, IxDyn>>()
{
let array = a_array.as_array();
let mean_value = if array.is_empty() {
0.0f64
} else {
array.iter().map(|&x| x as f64).sum::<f64>() / array.len() as f64
};
let result = ArrayD::<f64>::from_elem(IxDyn(&[1]), mean_value);
let result_wrapped = NdarrayWrapper::new(result);
let result_rc: Rc<dyn ArrayProtocol> = Rc::new(result_wrapped);
Ok(GradientTensor::from_op(
result_rc,
"mean".to_string(),
vec![a.clone()],
))
} else if let Some(a_array) = a_value
.as_any()
.downcast_ref::<NdarrayWrapper<u64, IxDyn>>()
{
let array = a_array.as_array();
let mean_value = if array.is_empty() {
0.0f64
} else {
array.iter().map(|&x| x as f64).sum::<f64>() / array.len() as f64
};
let result = ArrayD::<f64>::from_elem(IxDyn(&[1]), mean_value);
let result_wrapped = NdarrayWrapper::new(result);
let result_rc: Rc<dyn ArrayProtocol> = Rc::new(result_wrapped);
Ok(GradientTensor::from_op(
result_rc,
"mean".to_string(),
vec![a.clone()],
))
} else {
Err(CoreError::NotImplementedError(ErrorContext::new(
"mean not implemented for this array type".to_string(),
)))
}
}
/// A named, trainable parameter.
///
/// Wraps a leaf [`GradientTensor`] created with gradient tracking enabled.
pub struct Variable {
    tensor: GradientTensor,
    name: String,
}

impl Variable {
    /// Creates a variable called `name` that takes ownership of `array`.
    pub fn new<T, D>(name: &str, array: Array<T, D>) -> Self
    where
        T: Clone + Send + Sync + 'static,
        D: Dimension + Send + Sync + 'static,
    {
        let tensor = GradientTensor::from_array(array, true);
        Self {
            tensor,
            name: name.to_string(),
        }
    }

    /// Borrows the underlying graph tensor.
    pub const fn tensor(&self) -> &GradientTensor {
        &self.tensor
    }

    /// Returns a shared handle to the current value.
    pub fn value(&self) -> Rc<dyn ArrayProtocol> {
        self.tensor.value()
    }

    /// Returns the stored gradient, if any.
    pub fn grad_2(&self) -> Option<Rc<dyn ArrayProtocol>> {
        self.tensor.grad_2()
    }

    /// The variable's name.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Overwrites the stored gradient with `gradient`.
    ///
    /// # Errors
    /// Currently never fails; the `Result` is part of the public signature.
    pub fn set_gradient(&mut self, gradient: Box<dyn ArrayProtocol>) -> CoreResult<()> {
        let gradient_rc = self.box_to_rc(gradient);
        self.tensor.node.borrow_mut().grad = Some(gradient_rc);
        Ok(())
    }

    /// Replaces the value; this also clears the stored gradient
    /// (see `GradientTensor::set_value`).
    pub fn set_value(&mut self, newvalue: Box<dyn ArrayProtocol>) {
        let newvalue_rc = self.box_to_rc(newvalue);
        self.tensor.set_value(newvalue_rc);
    }

    /// Moves a boxed array into an `Rc`, preserving its concrete type.
    ///
    /// The previous implementation replaced every array that was not an
    /// `NdarrayWrapper<f64, IxDyn>` with a 1×1 zero array, so `f32` or
    /// integer values and gradients were silently lost.
    fn box_to_rc(&self, boxed: Box<dyn ArrayProtocol>) -> Rc<dyn ArrayProtocol> {
        Rc::from(boxed)
    }
}
/// Common interface for gradient-based optimizers over [`Variable`]s.
pub trait Optimizer {
    /// Applies one update to the registered variables.
    fn step(&mut self) -> CoreResult<()>;

    /// Clears the gradients of the registered variables.
    fn zero_grad(&mut self);

    /// Registers `var` with this optimizer.
    fn add_variable(&mut self, var: Variable);

    /// Returns the registered variables.
    fn variables(&self) -> &[Variable];

    /// Copies each gradient in `gradients` onto the first registered variable
    /// with the same name. Names with no matching variable are ignored.
    ///
    /// This goes through [`Optimizer::variables_mut`]. With that method's
    /// default (empty) implementation, no gradient is ever applied.
    ///
    /// # Errors
    /// Propagates any error from [`Variable::set_gradient`].
    fn accumulate_gradients(&mut self, gradients: &GradientDict) -> CoreResult<()> {
        for (param_name, gradient) in gradients.iter() {
            let target = self
                .variables_mut()
                .iter_mut()
                .find(|var| var.name() == param_name);
            if let Some(var) = target {
                var.set_gradient(gradient.clone())?;
            }
        }
        Ok(())
    }

    /// Mutable access to the registered variables; defaults to an empty slice.
    fn variables_mut(&mut self) -> &mut [Variable] {
        &mut []
    }
}
/// Stochastic gradient descent with optional momentum.
pub struct SGD {
    variables: Vec<Variable>,
    learningrate: f64,
    /// Momentum coefficient; values `<= 0.0` disable momentum.
    momentum: f64,
    /// Per-variable velocity (`None` until that variable's first momentum step).
    velocity: Vec<Option<Box<dyn ArrayProtocol>>>,
}

impl SGD {
    /// Creates an optimizer; `momentum` defaults to `0.0` (plain SGD).
    pub fn new(learningrate: f64, momentum: Option<f64>) -> Self {
        let momentum = momentum.unwrap_or(0.0);
        Self {
            velocity: Vec::new(),
            momentum,
            learningrate,
            variables: Vec::new(),
        }
    }

    /// Replaces the learning rate used by subsequent steps.
    pub fn set_learningrate(&mut self, learningrate: f64) {
        self.learningrate = learningrate;
    }
}

impl Optimizer for SGD {
    /// Updates every variable that currently has a gradient `g`:
    /// * without momentum: `x ← x − lr·g`;
    /// * with momentum: `v ← momentum·v + lr·g` (first step `v ← lr·g`),
    ///   then `x ← x − v`.
    fn step(&mut self) -> CoreResult<()> {
        let lr = self.learningrate;
        let momentum = self.momentum;
        for (idx, var) in self.variables.iter_mut().enumerate() {
            let grad = match var.grad_2() {
                Some(g) => g,
                None => continue,
            };
            let delta = if momentum > 0.0 {
                if self.velocity.len() <= idx {
                    self.velocity.resize_with(idx + 1, || None);
                }
                let next_velocity = match &self.velocity[idx] {
                    Some(previous) => {
                        let scaled_grad = multiply_by_scalar(grad.as_ref(), lr)?;
                        let scaled_vel = multiply_by_scalar(previous.as_ref(), momentum)?;
                        add(scaled_vel.as_ref(), scaled_grad.as_ref())?
                    }
                    None => multiply_by_scalar(grad.as_ref(), lr)?,
                };
                self.velocity[idx] = Some(next_velocity.clone());
                next_velocity
            } else {
                multiply_by_scalar(grad.as_ref(), lr)?
            };
            let current = var.value();
            let next_value = subtract_arrays(current.as_ref(), delta.as_ref())?;
            var.set_value(next_value);
        }
        Ok(())
    }

    fn zero_grad(&mut self) {
        self.variables
            .iter()
            .for_each(|var| var.tensor.node.borrow_mut().grad = None);
    }

    fn add_variable(&mut self, var: Variable) {
        self.variables.push(var);
        self.velocity.push(None);
    }

    fn variables(&self) -> &[Variable] {
        &self.variables
    }

    fn variables_mut(&mut self) -> &mut [Variable] {
        &mut self.variables
    }
}
/// Adam optimizer with bias-corrected first and second moment estimates.
pub struct Adam {
    variables: Vec<Variable>,
    learningrate: f64,
    /// Decay rate of the first-moment estimate.
    beta1: f64,
    /// Decay rate of the second-moment estimate.
    beta2: f64,
    /// Added to the denominator to avoid division by zero.
    epsilon: f64,
    /// Per-variable first-moment estimates (`None` before the first update).
    m: Vec<Option<Box<dyn ArrayProtocol>>>,
    /// Per-variable second-moment estimates (`None` before the first update).
    v: Vec<Option<Box<dyn ArrayProtocol>>>,
    /// Number of `step` calls so far; drives bias correction.
    t: usize,
}

impl Adam {
    /// Creates an optimizer. Defaults: `beta1 = 0.9`, `beta2 = 0.999`,
    /// `epsilon = 1e-8`.
    pub fn new(
        learningrate: f64,
        beta1: Option<f64>,
        beta2: Option<f64>,
        epsilon: Option<f64>,
    ) -> Self {
        let beta1 = beta1.unwrap_or(0.9);
        let beta2 = beta2.unwrap_or(0.999);
        let epsilon = epsilon.unwrap_or(1e-8);
        Self {
            t: 0,
            v: Vec::new(),
            m: Vec::new(),
            epsilon,
            beta2,
            beta1,
            learningrate,
            variables: Vec::new(),
        }
    }
}

impl Optimizer for Adam {
    /// Advances the step counter, then for each variable with gradient `g`:
    /// `m ← β1·m + (1−β1)·g`, `v ← β2·v + (1−β2)·g²`,
    /// `m̂ = m / (1−β1^t)`, `v̂ = v / (1−β2^t)`,
    /// `x ← x − lr · m̂ / (√v̂ + ε)`.
    fn step(&mut self) -> CoreResult<()> {
        self.t += 1;
        let (beta1, beta2) = (self.beta1, self.beta2);
        // The step counter is fixed for this call, so the bias-correction
        // factors are the same for every variable.
        let exponent = self.t as i32;
        let m_correction = 1.0 / (1.0 - beta1.powi(exponent));
        let v_correction = 1.0 / (1.0 - beta2.powi(exponent));

        for (idx, var) in self.variables.iter_mut().enumerate() {
            let grad = match var.grad_2() {
                Some(g) => g,
                None => continue,
            };
            let current = var.value();
            if self.m.len() <= idx {
                self.m.resize_with(idx + 1, || None);
                self.v.resize_with(idx + 1, || None);
            }

            let first_moment = match &self.m[idx] {
                Some(previous) => {
                    let decayed = multiply_by_scalar(previous.as_ref(), beta1)?;
                    let fresh = multiply_by_scalar(grad.as_ref(), 1.0 - beta1)?;
                    add(decayed.as_ref(), fresh.as_ref())?
                }
                None => multiply_by_scalar(grad.as_ref(), 1.0 - beta1)?,
            };
            let second_moment = match &self.v[idx] {
                Some(previous) => {
                    let decayed = multiply_by_scalar(previous.as_ref(), beta2)?;
                    let grad_sq = multiply(grad.as_ref(), grad.as_ref())?;
                    let fresh = multiply_by_scalar(grad_sq.as_ref(), 1.0 - beta2)?;
                    add(decayed.as_ref(), fresh.as_ref())?
                }
                None => {
                    let grad_sq = multiply(grad.as_ref(), grad.as_ref())?;
                    multiply_by_scalar(grad_sq.as_ref(), 1.0 - beta2)?
                }
            };
            self.m[idx] = Some(first_moment.clone());
            self.v[idx] = Some(second_moment.clone());

            let m_hat = multiply_by_scalar(first_moment.as_ref(), m_correction)?;
            let v_hat = multiply_by_scalar(second_moment.as_ref(), v_correction)?;
            let root = sqrt(v_hat.as_ref())?;
            let denominator = add_scalar(root.as_ref(), self.epsilon)?;
            let direction = divide(m_hat.as_ref(), denominator.as_ref())?;
            let delta = multiply_by_scalar(direction.as_ref(), self.learningrate)?;
            let next_value = subtract_arrays(current.as_ref(), delta.as_ref())?;
            var.set_value(next_value);
        }
        Ok(())
    }

    fn zero_grad(&mut self) {
        self.variables
            .iter()
            .for_each(|var| var.tensor.node.borrow_mut().grad = None);
    }

    fn add_variable(&mut self, var: Variable) {
        self.variables.push(var);
        self.m.push(None);
        self.v.push(None);
    }

    fn variables(&self) -> &[Variable] {
        &self.variables
    }

    fn variables_mut(&mut self) -> &mut [Variable] {
        &mut self.variables
    }
}
fn multiply_by_scalar(a: &dyn ArrayProtocol, scalar: f64) -> CoreResult<Box<dyn ArrayProtocol>> {
if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>() {
let inputarray = a_array.as_array();
let result = inputarray.mapv(|x| x * scalar);
Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
} else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f32, IxDyn>>() {
let inputarray = a_array.as_array();
let result = inputarray.mapv(|x| x * scalar as f32);
Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
} else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<i32, IxDyn>>() {
let inputarray = a_array.as_array();
let result = inputarray.mapv(|x| (x as f64 * scalar) as i32);
Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
} else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<i64, IxDyn>>() {
let inputarray = a_array.as_array();
let result = inputarray.mapv(|x| (x as f64 * scalar) as i64);
Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
} else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<u8, IxDyn>>() {
let inputarray = a_array.as_array();
let result = inputarray.mapv(|x| (x as f64 * scalar) as u8);
Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
} else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<u16, IxDyn>>() {
let inputarray = a_array.as_array();
let result = inputarray.mapv(|x| (x as f64 * scalar) as u16);
Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
} else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<u32, IxDyn>>() {
let inputarray = a_array.as_array();
let result = inputarray.mapv(|x| (x as f64 * scalar) as u32);
Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
} else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<u64, IxDyn>>() {
let inputarray = a_array.as_array();
let result = inputarray.mapv(|x| (x as f64 * scalar) as u64);
Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
} else {
Err(CoreError::NotImplementedError(ErrorContext::new(
"multiply_by_scalar not implemented for this array type".to_string(),
)))
}
}
fn subtract_arrays(
a: &dyn ArrayProtocol,
b: &dyn ArrayProtocol,
) -> CoreResult<Box<dyn ArrayProtocol>> {
if let (Some(a_wrapper), Some(b_array)) = (
a.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
b.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
) {
let a_arr = a_wrapper.as_array();
let b_arr = b_array.as_array();
let result = a_arr - b_arr;
Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
} else if let (Some(a_wrapper), Some(b_array)) = (
a.as_any().downcast_ref::<NdarrayWrapper<f32, IxDyn>>(),
b.as_any().downcast_ref::<NdarrayWrapper<f32, IxDyn>>(),
) {
let a_arr = a_wrapper.as_array();
let b_arr = b_array.as_array();
let result = a_arr - b_arr;
Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
} else if let (Some(a_wrapper), Some(b_array)) = (
a.as_any().downcast_ref::<NdarrayWrapper<i32, IxDyn>>(),
b.as_any().downcast_ref::<NdarrayWrapper<i32, IxDyn>>(),
) {
let a_arr = a_wrapper.as_array();
let b_arr = b_array.as_array();
let result = a_arr - b_arr;
Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
} else if let (Some(a_wrapper), Some(b_array)) = (
a.as_any().downcast_ref::<NdarrayWrapper<i64, IxDyn>>(),
b.as_any().downcast_ref::<NdarrayWrapper<i64, IxDyn>>(),
) {
let a_arr = a_wrapper.as_array();
let b_arr = b_array.as_array();
let result = a_arr - b_arr;
Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
} else if let (Some(a_wrapper), Some(b_array)) = (
a.as_any().downcast_ref::<NdarrayWrapper<u8, IxDyn>>(),
b.as_any().downcast_ref::<NdarrayWrapper<u8, IxDyn>>(),
) {
let a_arr = a_wrapper.as_array();
let b_arr = b_array.as_array();
let result = a_arr - b_arr;
Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
} else if let (Some(a_wrapper), Some(b_array)) = (
a.as_any().downcast_ref::<NdarrayWrapper<u16, IxDyn>>(),
b.as_any().downcast_ref::<NdarrayWrapper<u16, IxDyn>>(),
) {
let a_arr = a_wrapper.as_array();
let b_arr = b_array.as_array();
let result = a_arr - b_arr;
Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
} else if let (Some(a_wrapper), Some(b_array)) = (
a.as_any().downcast_ref::<NdarrayWrapper<u32, IxDyn>>(),
b.as_any().downcast_ref::<NdarrayWrapper<u32, IxDyn>>(),
) {
let a_arr = a_wrapper.as_array();
let b_arr = b_array.as_array();
let result = a_arr - b_arr;
Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
} else if let (Some(a_wrapper), Some(b_array)) = (
a.as_any().downcast_ref::<NdarrayWrapper<u64, IxDyn>>(),
b.as_any().downcast_ref::<NdarrayWrapper<u64, IxDyn>>(),
) {
let a_arr = a_wrapper.as_array();
let b_arr = b_array.as_array();
let result = a_arr - b_arr;
Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
} else {
Err(CoreError::NotImplementedError(ErrorContext::new(
"subtract_arrays not implemented for these array types".to_string(),
)))
}
}
/// Element-wise square root of `a`.
///
/// `f64` and `f32` arrays keep their element type. Integer arrays (`i32`,
/// `i64`, `u8`, `u16`, `u32`, `u64`) are converted to `f64` first, so the
/// result for them is an `f64` array (negative inputs produce NaN).
///
/// # Errors
///
/// Returns `CoreError::NotImplementedError` when `a` is not an
/// `NdarrayWrapper<T, IxDyn>` for one of the element types listed above.
fn sqrt(a: &dyn ArrayProtocol) -> CoreResult<Box<dyn ArrayProtocol>> {
    let erased = a.as_any();

    // Floating-point inputs: root taken in the native precision.
    if let Some(w) = erased.downcast_ref::<NdarrayWrapper<f64, IxDyn>>() {
        let roots = w.as_array().mapv(f64::sqrt);
        return Ok(Box::new(NdarrayWrapper::new(roots)) as Box<dyn ArrayProtocol>);
    }
    if let Some(w) = erased.downcast_ref::<NdarrayWrapper<f32, IxDyn>>() {
        let roots = w.as_array().mapv(f32::sqrt);
        return Ok(Box::new(NdarrayWrapper::new(roots)) as Box<dyn ArrayProtocol>);
    }

    // Integer inputs: promoted to f64 before the root is taken.
    if let Some(w) = erased.downcast_ref::<NdarrayWrapper<i32, IxDyn>>() {
        let roots = w.as_array().mapv(|v| (v as f64).sqrt());
        return Ok(Box::new(NdarrayWrapper::new(roots)) as Box<dyn ArrayProtocol>);
    }
    if let Some(w) = erased.downcast_ref::<NdarrayWrapper<i64, IxDyn>>() {
        let roots = w.as_array().mapv(|v| (v as f64).sqrt());
        return Ok(Box::new(NdarrayWrapper::new(roots)) as Box<dyn ArrayProtocol>);
    }
    if let Some(w) = erased.downcast_ref::<NdarrayWrapper<u8, IxDyn>>() {
        let roots = w.as_array().mapv(|v| (v as f64).sqrt());
        return Ok(Box::new(NdarrayWrapper::new(roots)) as Box<dyn ArrayProtocol>);
    }
    if let Some(w) = erased.downcast_ref::<NdarrayWrapper<u16, IxDyn>>() {
        let roots = w.as_array().mapv(|v| (v as f64).sqrt());
        return Ok(Box::new(NdarrayWrapper::new(roots)) as Box<dyn ArrayProtocol>);
    }
    if let Some(w) = erased.downcast_ref::<NdarrayWrapper<u32, IxDyn>>() {
        let roots = w.as_array().mapv(|v| (v as f64).sqrt());
        return Ok(Box::new(NdarrayWrapper::new(roots)) as Box<dyn ArrayProtocol>);
    }
    if let Some(w) = erased.downcast_ref::<NdarrayWrapper<u64, IxDyn>>() {
        let roots = w.as_array().mapv(|v| (v as f64).sqrt());
        return Ok(Box::new(NdarrayWrapper::new(roots)) as Box<dyn ArrayProtocol>);
    }

    Err(CoreError::NotImplementedError(ErrorContext::new(
        "sqrt not implemented for this array type".to_string(),
    )))
}
fn add_scalar(a: &dyn ArrayProtocol, scalar: f64) -> CoreResult<Box<dyn ArrayProtocol>> {
if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>() {
let result = a_array.as_array().mapv(|x| x + scalar);
Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
} else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f32, IxDyn>>() {
let result = a_array.as_array().mapv(|x| x + scalar as f32);
Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
} else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<i32, IxDyn>>() {
let result = a_array.as_array().mapv(|x| x + scalar as i32);
Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
} else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<i64, IxDyn>>() {
let result = a_array.as_array().mapv(|x| x + scalar as i64);
Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
} else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<u8, IxDyn>>() {
let result = a_array.as_array().mapv(|x| x + scalar as u8);
Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
} else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<u16, IxDyn>>() {
let result = a_array.as_array().mapv(|x| x + scalar as u16);
Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
} else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<u32, IxDyn>>() {
let result = a_array.as_array().mapv(|x| x + scalar as u32);
Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
} else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<u64, IxDyn>>() {
let result = a_array.as_array().mapv(|x| x + scalar as u64);
Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
} else {
Err(CoreError::NotImplementedError(ErrorContext::new(
"add_scalar not implemented for this array type".to_string(),
)))
}
}
fn divide(a: &dyn ArrayProtocol, b: &dyn ArrayProtocol) -> CoreResult<Box<dyn ArrayProtocol>> {
if let (Some(a_array), Some(b_array)) = (
a.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
b.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
) {
let result = a_array.as_array() / b_array.as_array();
Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
} else if let (Some(a_array), Some(b_array)) = (
a.as_any().downcast_ref::<NdarrayWrapper<f32, IxDyn>>(),
b.as_any().downcast_ref::<NdarrayWrapper<f32, IxDyn>>(),
) {
let result = a_array.as_array() / b_array.as_array();
Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
} else if let (Some(a_array), Some(b_array)) = (
a.as_any().downcast_ref::<NdarrayWrapper<i32, IxDyn>>(),
b.as_any().downcast_ref::<NdarrayWrapper<i32, IxDyn>>(),
) {
let result = ::ndarray::Zip::from(a_array.as_array())
.and(b_array.as_array())
.map_collect(|&av, &bv| av as f64 / bv as f64);
Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
} else if let (Some(a_array), Some(b_array)) = (
a.as_any().downcast_ref::<NdarrayWrapper<i64, IxDyn>>(),
b.as_any().downcast_ref::<NdarrayWrapper<i64, IxDyn>>(),
) {
let result = ::ndarray::Zip::from(a_array.as_array())
.and(b_array.as_array())
.map_collect(|&av, &bv| av as f64 / bv as f64);
Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
} else if let (Some(a_array), Some(b_array)) = (
a.as_any().downcast_ref::<NdarrayWrapper<u8, IxDyn>>(),
b.as_any().downcast_ref::<NdarrayWrapper<u8, IxDyn>>(),
) {
let result = ::ndarray::Zip::from(a_array.as_array())
.and(b_array.as_array())
.map_collect(|&av, &bv| av as f64 / bv as f64);
Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
} else if let (Some(a_array), Some(b_array)) = (
a.as_any().downcast_ref::<NdarrayWrapper<u16, IxDyn>>(),
b.as_any().downcast_ref::<NdarrayWrapper<u16, IxDyn>>(),
) {
let result = ::ndarray::Zip::from(a_array.as_array())
.and(b_array.as_array())
.map_collect(|&av, &bv| av as f64 / bv as f64);
Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
} else if let (Some(a_array), Some(b_array)) = (
a.as_any().downcast_ref::<NdarrayWrapper<u32, IxDyn>>(),
b.as_any().downcast_ref::<NdarrayWrapper<u32, IxDyn>>(),
) {
let result = ::ndarray::Zip::from(a_array.as_array())
.and(b_array.as_array())
.map_collect(|&av, &bv| av as f64 / bv as f64);
Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
} else if let (Some(a_array), Some(b_array)) = (
a.as_any().downcast_ref::<NdarrayWrapper<u64, IxDyn>>(),
b.as_any().downcast_ref::<NdarrayWrapper<u64, IxDyn>>(),
) {
let result = ::ndarray::Zip::from(a_array.as_array())
.and(b_array.as_array())
.map_collect(|&av, &bv| av as f64 / bv as f64);
Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
} else {
Err(CoreError::NotImplementedError(ErrorContext::new(
"divide not implemented for these array types".to_string(),
)))
}
}
// Unit tests for this module are loaded from `grad_tests.rs` (via the `#[path]`
// attribute) and compiled only for test builds.
#[cfg(test)]
#[path = "grad_tests.rs"]
mod tests;