ktensor/node/
state.rs

1use std::string::{String};
2use std::sync::{Arc};
3use std::ops::{Mul, Add};
4use math::{Vec2};
5use node::{Graph};
6use tensor::{Tensor};
7use context::{Context};
8
/// A trainable state node in the computation graph: a named tensor
/// variable stored in a `Context` and looked up by its string id.
pub struct State {
    // Unique identifier; used as the key into `Context` get/set calls.
    id: String,
    // Tensor dimensions — presumably (rows, cols); confirm with `math::Vec2`.
    dim: Vec2,
}
13
14impl State {
15    pub fn new(node_id: String, dimensions: Vec2) -> State {
16        State {
17            id: node_id,
18            dim: dimensions,
19        }
20    }
21
22    pub fn get_id(&self) -> String {
23        self.id.clone()
24    }
25
26    pub fn init_norm_f64(&self, context: &mut Context<f64>) {
27        context.set(self.get_id(), Tensor::<f64>::from_gaussian(self.dim, self.dim.0));
28    }
29
30    pub fn init_norm_f32(&self, context: &mut Context<f32>) {
31        context.set(self.get_id(), Tensor::<f32>::from_gaussian(self.dim, self.dim.0));
32    }
33
34    pub fn init_f64(vec_states: Vec<Arc<State>>, context: &mut Context<f64>) {
35        for state in vec_states {
36            state.init_norm_f64(context);
37        }
38    }
39
40    pub fn init_f32(vec_states: Vec<Arc<State>>, context: &mut Context<f32>) {
41        for state in vec_states {
42            state.init_norm_f32(context);
43        }
44    }
45}
46
47impl <T> Graph<T> for State where T: Copy + Mul<Output=T> + Add<Output=T> {
48    fn get_id(&self) -> String {
49        self.id.clone()
50    }
51
52    fn get_dim(&self) -> Vec2 {
53        self.dim
54    }
55
56    fn run(&self, state: &Context<T>, _: &Context<T>) -> Tensor<T> {
57        match state.get(self.get_id()) {
58            Some(x) => x.clone(),
59            None    => panic!("State {} does not exist in state", self.get_id()),
60        }
61    }
62
63    fn forward_pass(&self, state: &Context<T>, _: &Context<T>, _: &mut Context<T>) -> Tensor<T> {
64        match state.get(self.get_id()) {
65            Some(x) => x.clone(),
66            None    => panic!("State {} does not exist in state", self.get_id()),
67        }
68    }
69
70    fn backward_pass(&self, state: &mut Context<T>, _: &Context<T>, history: &Context<T>, gradient: &Tensor<T>, learning_rate: T) {
71        let delta = gradient * &learning_rate;
72        let previous_state = match history.get(self.get_id()) {
73            Some(x) => x,
74            None    => panic!("State {} does not exist in state", self.get_id()),
75        };
76        state.set(self.get_id(), previous_state + &delta);
77    }
78}