1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
use std::string::{String};
use std::collections::{HashMap};
use tensor::{Tensor};
use node::{Graph};
use math::{Vec2};
/// A lookup table that binds node id strings to `Tensor` values.
///
/// Entries are keyed by the `String` returned from a graph node's `get_id()`
/// (see `from_vec`), or by whatever id the caller passes to `set`.
pub struct Context<T> {
    /// Tensor currently bound to each node id.
    map: HashMap<String, Tensor<T>>,
}

impl<T> Default for Context<T> {
    /// Creates an empty context (same as `Context::new`, but without the
    /// `T: Copy` bound that the inherent impl carries).
    fn default() -> Context<T> {
        Context {
            map: HashMap::new(),
        }
    }
}

impl<T> Context<T>
where
    T: Copy,
{
    /// Creates an empty context.
    pub fn new() -> Context<T> {
        Context {
            map: HashMap::new(),
        }
    }

    /// Creates an empty context that can hold at least `size` entries
    /// before its underlying map needs to reallocate.
    pub fn with_capacity(size: usize) -> Context<T> {
        Context {
            map: HashMap::with_capacity(size),
        }
    }

    /// Builds a context from `(node, tensor)` pairs, keying each tensor by
    /// `node.get_id()`.
    ///
    /// If two pairs share the same node id, the later pair's tensor wins.
    ///
    /// # Panics
    ///
    /// Panics if any tensor's `dim()` differs from its node's `get_dim()` in
    /// either component. The panic message names the offending node id.
    pub fn from_vec(context_vec: Vec<(&Graph<T>, Tensor<T>)>) -> Context<T> {
        let map = context_vec
            .into_iter()
            .map(|(node, batch)| {
                // Compute the id once: it is needed for the messages and the key.
                let id = node.get_id();
                let Vec2(node_x, node_y) = node.get_dim();
                let Vec2(batch_x, batch_y) = batch.dim();
                assert_eq!(
                    node_x, batch_x,
                    "tensor bound to node `{}` differs from the node in dimension 0",
                    id
                );
                assert_eq!(
                    node_y, batch_y,
                    "tensor bound to node `{}` differs from the node in dimension 1",
                    id
                );
                (id, batch)
            })
            .collect();
        Context { map: map }
    }

    /// Returns the tensor bound to `nodeid`, or `None` if nothing is bound.
    ///
    /// Accepts any string-like key (`String`, `&str`, `&String`, ...), so
    /// lookups do not require the caller to allocate an owned `String`.
    pub fn get<K>(&self, nodeid: K) -> Option<&Tensor<T>>
    where
        K: AsRef<str>,
    {
        self.map.get(nodeid.as_ref())
    }

    /// Binds `tensor` to `nodeid`, replacing (and dropping) any tensor
    /// previously bound to that id.
    pub fn set(&mut self, nodeid: String, tensor: Tensor<T>) {
        self.map.insert(nodeid, tensor);
    }
}