// chandra/core/operations/assign.rs

use std::marker::PhantomData;

use crate::core::{
    operation::{Differentiable, Operation, OperationWrapper},
    processor::cpu::DifferentiatedCPUContext,
    types::{Computable, Void},
};

use super::var::Variable;
7pub fn assign<R: Computable, O: Operation<R>>(name: String, operation: O) -> (Variable<R>, OperationWrapper<Void, Assign<R,O>>) {
8    let variable = Variable::<R>::new(&name);
9    let instruction = OperationWrapper(Assign {
10        variable: variable.clone(),
11        assign: operation
12    },
13    PhantomData);
14
15    return (variable, instruction) 
16}
17
18#[derive(Clone, Debug)]
19pub struct Assign<R: Computable, O: Operation<R>> {
20    pub variable: Variable<R>,
21    pub assign: O
22}
23
24impl<R: Computable, O: Operation<R>> Operation<Void> for Assign<R, O> {
25    fn evaluate(&self, _context: &mut DifferentiatedCPUContext) -> Void {
26        todo!()
27       //let a = self.assign.evaluate(context);
28       //context.set::<R>(&self.variable.reference, a);
29       //Void
30    }
31}
32
33impl<R: Computable, O: Differentiable<R>> Differentiable<Void> for Assign<R, O> {
34    type Diff = OperationWrapper<Void, Assign<R, O::Diff>>;
35
36    fn auto_diff_for<R1: Clone>(&self, var: Variable<R1>, var_trace: &mut std::collections::HashMap<String, Vec<String>>) -> Self::Diff {
37        if self.assign.contains_var(var.clone()) {
38            var_trace.insert(self.variable.reference.clone(), vec![var.reference.clone()]);
39        } else {
40            var_trace.insert(self.variable.reference.clone(), vec![]);
41        }
42
43        assign(self.variable.reference.clone(), self.assign.auto_diff_for(var, var_trace)).1
44    }
45
46    fn contains_var<R1: Clone>(&self, var: Variable<R1>) -> bool {
47        self.assign.contains_var(var)
48    }
49}