1use std::string::{String};
2use std::sync::{Arc};
3use std::ops::{Mul, Add};
4use math::{Vec2};
5use node::{Graph};
6use tensor::{Tensor};
7use context::{Context};
8
9pub struct State {
10 id: String,
11 dim: Vec2,
12}
13
14impl State {
15 pub fn new(node_id: String, dimensions: Vec2) -> State {
16 State {
17 id: node_id,
18 dim: dimensions,
19 }
20 }
21
22 pub fn get_id(&self) -> String {
23 self.id.clone()
24 }
25
26 pub fn init_norm_f64(&self, context: &mut Context<f64>) {
27 context.set(self.get_id(), Tensor::<f64>::from_gaussian(self.dim, self.dim.0));
28 }
29
30 pub fn init_norm_f32(&self, context: &mut Context<f32>) {
31 context.set(self.get_id(), Tensor::<f32>::from_gaussian(self.dim, self.dim.0));
32 }
33
34 pub fn init_f64(vec_states: Vec<Arc<State>>, context: &mut Context<f64>) {
35 for state in vec_states {
36 state.init_norm_f64(context);
37 }
38 }
39
40 pub fn init_f32(vec_states: Vec<Arc<State>>, context: &mut Context<f32>) {
41 for state in vec_states {
42 state.init_norm_f32(context);
43 }
44 }
45}
46
47impl <T> Graph<T> for State where T: Copy + Mul<Output=T> + Add<Output=T> {
48 fn get_id(&self) -> String {
49 self.id.clone()
50 }
51
52 fn get_dim(&self) -> Vec2 {
53 self.dim
54 }
55
56 fn run(&self, state: &Context<T>, _: &Context<T>) -> Tensor<T> {
57 match state.get(self.get_id()) {
58 Some(x) => x.clone(),
59 None => panic!("State {} does not exist in state", self.get_id()),
60 }
61 }
62
63 fn forward_pass(&self, state: &Context<T>, _: &Context<T>, _: &mut Context<T>) -> Tensor<T> {
64 match state.get(self.get_id()) {
65 Some(x) => x.clone(),
66 None => panic!("State {} does not exist in state", self.get_id()),
67 }
68 }
69
70 fn backward_pass(&self, state: &mut Context<T>, _: &Context<T>, history: &Context<T>, gradient: &Tensor<T>, learning_rate: T) {
71 let delta = gradient * &learning_rate;
72 let previous_state = match history.get(self.get_id()) {
73 Some(x) => x,
74 None => panic!("State {} does not exist in state", self.get_id()),
75 };
76 state.set(self.get_id(), previous_state + &delta);
77 }
78}