auto_diff/compute_graph.rs

#![allow(clippy::redundant_closure)]
use std::collections::{BTreeSet, BTreeMap};
use std::fmt;

use crate::collection::generational_index::{GenIndex, GenKey};
use crate::collection::directed_graph::Graph;
use tensor_rs::tensor::Tensor;
use crate::op::Op;
use crate::err::AutoDiffError;

#[cfg(feature = "use-serde")]
use serde::{Serialize, Deserialize};

/// The computation network.
/// Connections may contain duplicates.
#[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
#[derive(Clone)]
pub struct Net {
    data: GenIndex<Tensor>,
    ops: GenIndex<Op>,
    set_mark: BTreeSet<GenKey>,
    graph: Graph<GenKey, GenKey>,
    data_grad: BTreeMap<GenKey, Tensor>,
    label2id: BTreeMap<String, GenKey>, // Optional human-readable names for variables.
}
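
// Design note: `graph` records a bipartite relation between data nodes and
// op nodes, both addressed by `GenKey`. The tensors and operators themselves
// live in the `data` and `ops` generational indexes, so the graph only
// stores keys.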

impl Net {
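    /// Create an empty network.
    ///
    /// A minimal end-to-end sketch (marked `ignore` since it needs an `Op`
    /// instance; `some_op` is a placeholder for an operator constructed
    /// elsewhere in the crate):
    ///
    /// ```ignore
    /// let mut net = Net::new();
    /// let x = net.add_tensor(Tensor::new());
    /// let y = net.add_tensor(Tensor::new());
    /// let op = net.add_op(some_op);
    /// net.connect(&[x], op, &[y]);
    /// net.eval(&[x]).expect("forward pass failed");
    /// ```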
    pub fn new() -> Net {
        Net {
            data: GenIndex::new(),
            ops: GenIndex::new(),
            set_mark: BTreeSet::new(),
            graph: Graph::new(),
            data_grad: BTreeMap::new(),
            label2id: BTreeMap::new(),
        }
    }

    pub fn get_data(&self) -> &GenIndex<Tensor> {
        &self.data
    }

    pub fn get_data_mut(&mut self) -> &mut GenIndex<Tensor> {
        &mut self.data
    }

    pub fn get_ops(&self) -> &GenIndex<Op> {
        &self.ops
    }

    pub fn get_ops_mut(&mut self) -> &mut GenIndex<Op> {
        &mut self.ops
    }

    pub fn add_tensor(&mut self, t: Tensor) -> GenKey {
        let id = self.data.insert(t);
        self.graph.add_data(&id).expect("a freshly inserted key should be unique in the graph");
        id
    }

    pub fn get_tensor(&self, id: GenKey) -> Result<Tensor, AutoDiffError> {
        // ref_copy() makes a shallow copy of the tensor.
        Ok(self.data.get(&id)?.ref_copy())
    }

    pub fn set_tensor(&mut self, id: GenKey, val: Tensor) -> Result<(), AutoDiffError> {
        self.data.replace(&id, val)
    }

    /// Insert an operator into the network.
    pub fn add_op(&mut self, op: Op) -> GenKey {
        let id = self.ops.insert(op);
        self.graph.add_op(&id).expect("a freshly inserted key should be unique in the graph");
        id
    }

    pub fn get_op(&self, id: GenKey) -> Result<Op, AutoDiffError> {
        Ok(self.ops.get(&id)?.ref_copy())
    }

    pub fn get_grad(&self, id: GenKey) -> Result<Tensor, AutoDiffError> {
        match self.data_grad.get(&id) {
            Some(v) => Ok(v.ref_copy()),
            None => Err(AutoDiffError::new(&format!("Data {:?} does not have a gradient yet.", id))),
        }
    }

    pub fn get_input_edge_data(&self) -> BTreeSet<GenKey> {
        self.graph.get_input_edge_data()
    }

    pub fn get_output_edge_data(&self) -> BTreeSet<GenKey> {
        self.graph.get_output_edge_data()
    }


//    pub fn is_dangling_var(&self, var: &Var) -> Result<bool, ()> {
//        if !self.data.contains(var.get_id()) {
//            Err(())
//        } else if self.graph.iter_op_given_input(var.get_id()).expect("").count() == 0 &&
//            self.graph.iter_op_given_output(var.get_id()).expect("").count() == 0 {
//                Ok(true)
//            } else {
//                Ok(false)
//            }
//
//    }


//    ///
//    /// Remove a concrete op or composed func from the graph.
//    ///
//    pub fn del_func_or_op(&mut self, func: &Func) {
//        let _ = self.ops.remove(func.get_id());
//        let _ = self.graph.drop_op(func.get_id());
//
//        // ignore the result so as to allow duplicate deletes
//
//        //
//        // The following doesn't work,
//        // because the composed func going out of scope doesn't mean
//        //     its member ops go out of scope.
//        //
//        // Check the func type:
//        // if it is an op, delete it;
//        // if it is a func, find all the underlying ops
//        //     and vars in between and remove them.
//        //
//
//    }

//    ///
//    /// Disconnect the variable from the function that takes it as input.
//    /// Delete the variable if it becomes a dangling variable.
//    ///
//    pub fn decouple_input(&mut self, func: &Func) -> Vec<GenKey> {
//        let mut decoupled_inputs = Vec::new();
//        let inputs: Vec<GenKey> = self.graph.iter_input_given_op(func.get_id())
//            .expect("").map(|x| x.clone()).collect();
//        for i in inputs {
//            self.graph.decouple_data_func(&i, func.get_id()).expect("");
//            decoupled_inputs.push(i);
//        }
//        decoupled_inputs
//    }


    ///
    /// Build the input-operator-output relation from the given components.
    ///
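    /// A sketch (marked `ignore`; `a`, `b`, `out` are keys returned by
    /// `add_tensor`, and `op` is a key returned by `add_op`):
    ///
    /// ```ignore
    /// net.connect(&[a, b], op, &[out]);
    /// ```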
    pub fn connect(&mut self, input: &[GenKey], op: GenKey, output: &[GenKey]) {
        self.graph.connect(input, output, &op).expect("the data and op keys should already be in the graph");
    }

    /// Set the mark; `set_mark` labels a variable that has an input value attached.
    pub fn set_mark(&mut self, did: &GenKey) {
        self.set_mark.insert(*did);
    }

    pub fn unset_mark(&mut self, did: &GenKey) {
        self.set_mark.remove(did);
    }

    /// Forward evaluate the computation graph.
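    ///
    /// Sketch of a typical call (marked `ignore`; `x` and `y` are keys of
    /// input tensors added earlier with `add_tensor`):
    ///
    /// ```ignore
    /// net.eval(&[x, y]).expect("forward pass failed");
    /// ```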
    pub fn eval(&mut self, starting_node: &[GenKey]) -> Result<(), BTreeSet<GenKey>> {
        self.graph
            .walk(
                starting_node,
                true,
                |input, output, op| {
                    //println!("op: {}", self.ops.get(op).expect("").get_name());

                    let mut inputs: Vec<Tensor> = Vec::new();
                    for input_id in input {
                        let a = self.data.get(input_id).expect("input tensor should exist").ref_copy();
                        inputs.push(a);
                    }

                    let mut outputs: Vec<Tensor> = Vec::new();
                    for output_id in output {
                        let a = self.data.get(output_id).expect("output tensor should exist").ref_copy();
                        outputs.push(a);
                    }

                    self.ops
                        .get(op)
                        .expect("op should exist")
                        .apply(&inputs, &outputs);

                    //println!("var.rs: {:?}", outputs[0].size());
                }
            )?;

        Ok(())
    }

//    pub fn eval_op(&self, input: &[&Var], func: &Func, output: &[&Var]) {
//        let mut inputs: Vec<&Tensor> = Vec::new();
//        for input_var in input {
//            let a = self.data.get(input_var.get_id()).expect("");
//            inputs.push(a);
//        }
//
//        let mut outputs: Vec<&Tensor> = Vec::new();
//        for output_var in output {
//            let a = self.data.get(output_var.get_id()).expect("");
//            outputs.push(a);
//        }
//
//        self.ops
//            .get(func.get_id())
//            .expect("")
//            .apply(&inputs, &outputs);
//    }

//    pub fn bptt_scale(&mut self, r: f32) {
//        let output = self.graph.get_output_edge_data();
//        let mut output_grad = BTreeMap::new();
//        for i in &output {
//            output_grad.insert(*i,
//                               Tensor::fill(&self.data.get(i).expect("").size(),
//                                            r));
//        }
//        self.bptt(&output_grad);
//    }

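    /// Back-propagate through the graph, seeded with the given gradients of
    /// the output tensors.
    ///
    /// A minimal sketch (marked `ignore`), mirroring the commented-out
    /// `bptt_scale` above; `out` is the key of an output tensor:
    ///
    /// ```ignore
    /// let mut seed = BTreeMap::new();
    /// let shape = net.get_tensor(out).expect("output exists").size();
    /// seed.insert(out, Tensor::fill(&shape, 1.0));
    /// net.bptt(&seed);
    /// ```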
    pub fn bptt(&mut self, output_grad: &BTreeMap<GenKey, Tensor>) {
        let mut output = Vec::new();
        self.data_grad.clear();
        for (k, v) in output_grad {
            output.push(*k);
            self.data_grad.insert(*k, v.clone());
        }

        for i in self.graph.iter_data() {
            self.data_grad.entry(*i).or_insert_with(Tensor::new);
        }

        self.graph
            .walk(
                &output[..],
                false,
                |output_grads, input_grads, op| {
                    //println!("op, bptt: {}", self.ops.get(op).expect("").get_name());

                    // Collect the input tensors.
                    let mut inputs: Vec<Tensor> = Vec::new();
                    for input_id in input_grads {
                        let a = self.data.get(input_id).expect("input tensor should exist").ref_copy();
                        inputs.push(a);
                    }

                    // Collect the output tensor gradients (forward view).
                    let mut output_grad: Vec<Tensor> = Vec::new();
                    for output_id in output_grads {
                        let a = self.data_grad.get(output_id).expect("output gradient should exist").ref_copy();
                        output_grad.push(a);
                    }

                    // Collect the input tensor gradients (forward view).
                    let mut input_grad: Vec<Tensor> = Vec::new();
                    for input_id in input_grads {
                        let a = self.data_grad.get(input_id).expect("input gradient should exist").ref_copy();
                        input_grad.push(a);
                    }

                    self.ops
                        .get(op)
                        .expect("op should exist")
                        .grad(&inputs, &output_grad, &input_grad);
                }
            ).expect("backward walk should reach all seeded outputs");
    }

    /// Iterate over all ops; no order is guaranteed.
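    ///
    /// E.g. (sketch): `net.visit_op(|op| println!("{}", op.get_name()), None, None);`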
    pub fn visit_op<F>(&mut self, closure: F,
                       allow: Option<Vec<GenKey>>,
                       skip: Option<Vec<GenKey>>)
    where F: Fn(&Op) {
        let allow_list = allow.unwrap_or_default();
        let skip_list = skip.unwrap_or_default();

        for i in self.graph.iter_op() {
            if (allow_list.is_empty() && skip_list.is_empty()) ||
                (!allow_list.is_empty() && allow_list.contains(i)) ||
                (!skip_list.is_empty() && !skip_list.contains(i)) {
                    closure(self.ops.get(i).expect("op should exist"));
            }
        }
    }

    pub fn visit_data<F>(&mut self, closure: F)
    where F: Fn(GenKey, &Tensor) {
        for i in self.graph.iter_data() {
            closure(*i, self.data.get(i).expect("data tensor should exist"));
        }
    }

    /// Move the contents of another network into `self`.
    /// Return the new ids of the entries of `original_keys` from the old network.
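    ///
    /// Sketch (marked `ignore`; `other` is another `Net` and `k` a key in it):
    ///
    /// ```ignore
    /// let new_keys = net.append(&other, &[k])?;
    /// ```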
    pub fn append(&mut self, other: &Self,
                  original_keys: &[GenKey]) -> Result<Vec<GenKey>, AutoDiffError> {
        let mut data_key_map = BTreeMap::new();
        let mut ret_keys = Vec::new();
        for key in other.get_data().iter_key() {
            let new_key = self.add_tensor(other.get_tensor(key)?);
            if original_keys.contains(&key) {
                ret_keys.push(new_key);
            }
            data_key_map.insert(key, new_key);
        }

        let mut op_key_map = BTreeMap::new();
        for key in other.get_ops().iter_key() {
            let new_key = self.add_op(other.get_op(key)?);
            op_key_map.insert(key, new_key);
        }

        self.graph.append(&other.graph, data_key_map, op_key_map)?;

        Ok(ret_keys)
    }

    /// For introspection.
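    ///
    /// E.g. (sketch): `net.set_label("loss", &loss_key)?;` then
    /// `net.get_id_by_label("loss")?` recovers the key.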
    pub fn set_label(&mut self, label: &str, id: &GenKey) -> Result<(), AutoDiffError> {
        if !self.data.contains(id) {
            Err(AutoDiffError::new("unknown id."))
        } else {
            self.label2id.insert(label.to_string(), *id);
            Ok(())
        }
    }

    pub fn get_id_by_label(&self, label: &str) -> Result<GenKey, AutoDiffError> {
        match self.label2id.get(label) {
            Some(v) => Ok(*v),
            None => Err(AutoDiffError::new("unknown label.")),
        }
    }

    /// Remove the label from the map, returning the key it referred to.
    pub fn drop_label(&mut self, label: &str) -> Result<GenKey, AutoDiffError> {
        self.label2id.remove(label)
            .ok_or_else(|| AutoDiffError::new("unknown label."))
    }
}

impl fmt::Debug for Net {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        writeln!(f, "Dumping Net:")?;
        for i in self.data.iter_key() {
            writeln!(f, "id: {:?}  data: {:?}", i, self.data.get(&i)?)?;
        }
        writeln!(f, "data_grad")?;
        for (k, v) in self.data_grad.iter() {
            writeln!(f, "id: {:?}  data: {:?}", k, v)?;
        }
        writeln!(f, "op names")?;
        for i in self.ops.iter_key() {
            writeln!(f, "id: {:?} \n data: {:?}", i, self.ops.get(&i)?.get_name())?;
        }
        writeln!(f, "graph: {:?}", self.graph)
    }
}

impl Default for Net {
    fn default() -> Self {
        Self::new()
    }
}

//impl PartialEq for Net {
//    fn eq(&self, other: &Self) -> bool {
//        unimplemented!();
//    }
//}
//
//impl Eq for Net {}