1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
use std::string::{String};
use std::collections::{HashMap};
use tensor::{Tensor};
use node::{Graph};
use math::{Vec2};

/// Lookup table that associates node ids with tensors.
///
/// Keys are `String` ids (in `from_vec` they come from `Graph::get_id`);
/// each value is the `Tensor<T>` stored for that id.
pub struct Context<T> {
    /// Node id -> tensor currently stored for that node.
    map: HashMap<String, Tensor<T>>,
}

impl<T: Copy> Context<T> {
    /// Creates a context that holds no entries.
    pub fn new() -> Context<T> {
        let empty = HashMap::new();
        Context { map: empty }
    }

    /// Creates an empty context whose map is pre-sized for at least `size`
    /// entries.
    pub fn with_capacity(size: usize) -> Context<T> {
        let presized = HashMap::with_capacity(size);
        Context { map: presized }
    }

    /// Builds a context from `(node, tensor)` pairs.
    ///
    /// Each tensor is stored under the id returned by `node.get_id()`. When
    /// two pairs yield the same id, the later pair replaces the earlier one.
    ///
    /// # Panics
    ///
    /// Panics if, for any pair, either component of `node.get_dim()` differs
    /// from the matching component of the tensor's `dim()`.
    pub fn from_vec(context_vec: Vec<(&Graph<T>, Tensor<T>)>) -> Context<T> {
        let mut entries: HashMap<String, Tensor<T>> =
            HashMap::with_capacity(context_vec.len());

        entries.extend(context_vec.into_iter().map(|(node, batch)| {
            let Vec2(node_x, node_y) = node.get_dim();
            let Vec2(batch_x, batch_y) = batch.dim();
            // Check both components before the pair is allowed into the map.
            assert_eq!(node_x, batch_x);
            assert_eq!(node_y, batch_y);
            (node.get_id(), batch)
        }));

        Context { map: entries }
    }

    /// Returns a reference to the tensor stored under `nodeid`, or `None` if
    /// the context has no entry for that id.
    pub fn get(&self, nodeid: String) -> Option<&Tensor<T>> {
        self.map.get(nodeid.as_str())
    }

    /// Stores `tensor` under `nodeid`, replacing any tensor already stored
    /// for that id.
    pub fn set(&mut self, nodeid: String, tensor: Tensor<T>) {
        // The previously stored tensor (if any) is intentionally discarded.
        let _ = self.map.insert(nodeid, tensor);
    }
}