1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78
use std::string::String;
use std::sync::Arc;
use std::ops::{Add, Mul};

use math::Vec2;
use node::Graph;
use tensor::Tensor;
use context::Context;

/// A graph node whose value is stored in a `Context` under the node's id.
///
/// The node itself only carries its identifier and its dimensions; the tensor
/// it stands for is looked up in (and written back to) the contexts handed to
/// the `Graph` methods.
pub struct State {
    /// Key under which this node's tensor is stored in a `Context`.
    id: String,
    /// Dimensions used when the node's tensor is created.
    dim: Vec2,
}

impl State {
    /// Creates a state node with the given identifier and dimensions.
    pub fn new(node_id: String, dimensions: Vec2) -> State {
        State {
            id: node_id,
            dim: dimensions,
        }
    }

    /// Returns an owned copy of this node's identifier.
    pub fn get_id(&self) -> String {
        self.id.clone()
    }

    /// Stores a freshly generated `f64` tensor for this node in `context`.
    ///
    /// The tensor comes from `Tensor::from_gaussian(self.dim, self.dim.0)`.
    /// NOTE(review): the meaning of the second argument (the first component
    /// of `dim`) is defined by `Tensor::from_gaussian` and is not visible
    /// here; confirm against the tensor module.
    pub fn init_norm_f64(&self, context: &mut Context<f64>) {
        let initial = Tensor::<f64>::from_gaussian(self.dim, self.dim.0);
        context.set(self.get_id(), initial);
    }

    /// Stores a freshly generated `f32` tensor for this node in `context`.
    ///
    /// Single-precision counterpart of `init_norm_f64`; the same note about
    /// the second `from_gaussian` argument applies.
    pub fn init_norm_f32(&self, context: &mut Context<f32>) {
        let initial = Tensor::<f32>::from_gaussian(self.dim, self.dim.0);
        context.set(self.get_id(), initial);
    }

    /// Calls `init_norm_f64` on every state in `vec_states`, in order.
    pub fn init_f64(vec_states: Vec<Arc<State>>, context: &mut Context<f64>) {
        for state in vec_states {
            state.init_norm_f64(context);
        }
    }

    /// Calls `init_norm_f32` on every state in `vec_states`, in order.
    pub fn init_f32(vec_states: Vec<Arc<State>>, context: &mut Context<f32>) {
        for state in vec_states {
            state.init_norm_f32(context);
        }
    }
}

impl<T> Graph<T> for State
where
    T: Copy + Mul<Output = T> + Add<Output = T>,
{
    /// Returns an owned copy of this node's identifier.
    fn get_id(&self) -> String {
        self.id.clone()
    }

    /// Returns this node's dimensions.
    fn get_dim(&self) -> Vec2 {
        self.dim
    }

    /// Returns a clone of the tensor stored for this node in `state`.
    ///
    /// The second context is not consulted.
    ///
    /// # Panics
    ///
    /// Panics if `state` holds no entry for this node's id.
    fn run(&self, state: &Context<T>, _: &Context<T>) -> Tensor<T> {
        match state.get(self.get_id()) {
            Some(value) => value.clone(),
            None => panic!("State {} does not exist in state", self.get_id()),
        }
    }

    /// Returns a clone of the tensor stored for this node in `state`.
    ///
    /// A state node performs the same lookup here as in `run`, so this simply
    /// delegates to it; the third context is never written.
    ///
    /// # Panics
    ///
    /// Panics if `state` holds no entry for this node's id.
    fn forward_pass(
        &self,
        state: &Context<T>,
        parameters: &Context<T>,
        _: &mut Context<T>,
    ) -> Tensor<T> {
        // Fully qualified so the call is pinned to the `Graph<T>` impl for this `T`.
        <State as Graph<T>>::run(self, state, parameters)
    }

    /// Writes an updated tensor for this node into `state`.
    ///
    /// The new value is the tensor recorded for this node in `history` plus
    /// `gradient * learning_rate`. The second context is not consulted.
    /// NOTE(review): the scaled gradient is added as-is; any sign convention
    /// (e.g. descent vs. ascent) has to be applied by the caller; confirm.
    ///
    /// # Panics
    ///
    /// Panics if `history` holds no entry for this node's id.
    fn backward_pass(
        &self,
        state: &mut Context<T>,
        _: &Context<T>,
        history: &Context<T>,
        gradient: &Tensor<T>,
        learning_rate: T,
    ) {
        let delta = gradient * &learning_rate;
        let previous_state = match history.get(self.get_id()) {
            Some(recorded) => recorded,
            // The failed lookup is in `history`, so the message names that context.
            None => panic!("State {} does not exist in history", self.get_id()),
        };
        state.set(self.get_id(), previous_state + &delta);
    }
}