auto_diff/collection/
directed_graph.rs

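//! A directed graph whose nodes alternate between op nodes and data nodes:
//! every edge runs either from a data node into an op (one of the op's inputs)
//! or from an op to a data node (one of its outputs), so data nodes act as the
//! edges between ops.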
3use std::collections::{BTreeMap, BTreeSet};
4use std::fmt;
5use crate::err::AutoDiffError;
6use super::generational_index::GenKey;
7
8#[cfg(feature = "use-serde")]
9use serde::{Serialize, Deserialize};
10
/// A directed bipartite graph of data nodes (`TData`) and op nodes (`TOp`).
///
/// Edges only ever join a data node to an op (the data is one of the op's
/// inputs) or an op to a data node (the data is one of the op's outputs).
/// Each edge is recorded twice, once per direction, so neighbours can be
/// looked up both forwards and backwards. Every data node has an entry
/// (possibly empty) in `forward_dt_op` and `backward_dt_op`, and every op node
/// has one in `forward_op_dt` and `backward_op_dt`.
#[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
#[derive(Clone)]
pub struct Graph<TData: Ord, TOp: Ord> {
    /// All data node ids.
    data: BTreeSet<TData>,
    /// All op node ids.
    op: BTreeSet<TOp>,
    /// data -> ops that take it as an input.
    forward_dt_op: BTreeMap<TData, BTreeSet<TOp>>,
    /// op -> data it produces as outputs.
    forward_op_dt: BTreeMap<TOp, BTreeSet<TData>>,
    /// data -> ops that produce it as an output.
    backward_dt_op: BTreeMap<TData, BTreeSet<TOp>>,
    /// op -> data it takes as inputs.
    backward_op_dt: BTreeMap<TOp, BTreeSet<TData>>,
}
22
23impl<TData: Clone + Copy + Ord, TOp: Clone + Copy + Ord> Default for Graph<TData, TOp> {
24    fn default() -> Graph<TData, TOp> {
25        Graph{
26            data: BTreeSet::new(),
27            op: BTreeSet::new(),
28            forward_dt_op: BTreeMap::new(),
29            forward_op_dt: BTreeMap::new(),
30            backward_dt_op: BTreeMap::new(),
31            backward_op_dt: BTreeMap::new(),
32        }
33    }
34}
35
36impl<TData: Clone + Copy + Ord, TOp: Clone + Copy + Ord> Graph<TData, TOp> {
37    /// Create a graph with defaults
38    pub fn new() -> Graph<TData, TOp> {
39        Graph{
40            data: BTreeSet::new(),
41            op: BTreeSet::new(),
42            forward_dt_op: BTreeMap::new(),
43            forward_op_dt: BTreeMap::new(),
44            backward_dt_op: BTreeMap::new(),
45            backward_op_dt: BTreeMap::new(),
46        }
47    }
48
49    /// iterator over data node.
50    pub fn iter_data(&self) -> NodeIterator<TData> {
51        NodeIterator {
52            iter: self.data.iter()
53        }
54    }
55    /// iterator over op node.
56    pub fn iter_op(&self) -> NodeIterator<TOp> {
57        NodeIterator {
58            iter: self.op.iter()
59        }
60    }
61
62    ///
63    /// Return the list of ops that the given variable is the input of the func.
64    ///
65    pub fn iter_op_given_input(&self, var: &TData) -> Result<NodeIterator<TOp>, &str> {
66        if !self.data.contains(var) {
67            Err("Not a valid variable/data")
68        } else {
69            Ok(NodeIterator {
70                iter: self.forward_dt_op.get(var).expect("").iter()
71            })
72        }
73    }
74
75    ///
76    /// Return the list of ops that the given variable is the output.
77    ///
78    pub fn iter_op_given_output(&self, var: &TData) -> Result<NodeIterator<TOp>, &str> {
79        if !self.data.contains(var) {
80            Err("Not a valid variable/data")
81        } else {
82            Ok(NodeIterator {
83                iter: self.backward_dt_op.get(var).expect("").iter()
84            })
85        }
86    }
87
88    ///
89    /// Return the list of input given the func.
90    ///
91    pub fn iter_input_given_op(&self, func: &TOp) -> Result<NodeIterator<TData>, &str> {
92        if !self.op.contains(func) {
93            Err("Bad func id.")
94        } else {
95            Ok(NodeIterator {
96                iter: self.backward_op_dt.get(func).expect("").iter()
97            })
98        }
99    }
100
101    ///
102    /// Return a list of data as the output of the op.
103    ///
104    pub fn iter_output_given_op(&self, func: &TOp) -> Result<NodeIterator<TData>, &str> {
105        if !self.op.contains(func) {
106            Err("Bad func id.")
107        } else {
108            Ok(NodeIterator {
109                iter: self.forward_op_dt.get(func).expect("").iter()
110            })
111        }
112    }
113
114    /// Add a data node.
115    pub fn add_data(&mut self, id: &TData) -> Result<TData, &str> {
116        if !self.data.contains(id) {
117            self.data.insert(*id);
118            self.forward_dt_op.insert(*id, BTreeSet::new());
119            self.backward_dt_op.insert(*id, BTreeSet::new());
120            Ok(*id)
121        } else {
122            Err("data is exits!")
123        }
124    }
125
126    /// Remove a data node, op node and downstream data/op node are removed.
127    pub fn drop_data(&mut self, id: &TData) -> Result<TData, &str> {
128        if self.data.contains(id) {
129            self.data.remove(id);
130            for i in self.forward_dt_op.get_mut(id).expect("").iter() {
131                self.backward_op_dt.get_mut(i).expect("").remove(id);
132            }
133            self.forward_dt_op.remove(id);
134            for i in self.backward_dt_op.get_mut(id).expect("").iter() {
135                self.forward_op_dt.get_mut(i).expect("").remove(id);
136            }
137            self.backward_dt_op.remove(id);
138
139            Ok(*id)
140        } else {
141            Err("data id is not found!")
142        }
143    }
144
145    /// Add a danglging op node.
146    pub fn add_op(&mut self, id: &TOp) -> Result<TOp, &str> {
147        if !self.op.contains(id) {
148            self.op.insert(*id);
149            self.forward_op_dt.insert(*id, BTreeSet::new());
150            self.backward_op_dt.insert(*id, BTreeSet::new());
151            Ok(*id)
152        } else {
153            Err("op id exists.")
154        }
155    }
156
157    /// Remvoe an op node, input data node and downstream data/op node are removed.
158    pub fn drop_op(&mut self, id: &TOp) -> Result<TOp, &str> {
159        if self.op.contains(id) {
160            self.op.remove(id);
161            for i in self.forward_op_dt.get_mut(id).expect("").iter() {
162                self.backward_dt_op.get_mut(i).expect("").remove(id);
163            }
164            self.forward_op_dt.remove(id);
165            for i in self.backward_op_dt.get_mut(id).expect("").iter() {
166                self.forward_dt_op.get_mut(i).expect("").remove(id);
167            }
168            self.backward_op_dt.remove(id);
169            Ok(*id)
170        } else {
171            Err("op id is not found!")
172        }
173
174    }
175
176    ///
177    /// Decouple input variable and op
178    ///
179    pub fn decouple_data_func(&mut self, var: &TData, func: &TOp) -> Result<(), AutoDiffError> {
180        if self.data.contains(var) && self.op.contains(func) {
181            self.forward_dt_op.get_mut(var).expect("").remove(func);
182            self.backward_op_dt.get_mut(func).expect("").remove(var);
183            Ok(())
184        } else {
185            Err(AutoDiffError::new("invalid var or func"))
186        }
187    }
188
189    ///
190    /// Decouple op and output variable
191    ///
192    pub fn decouple_func_data(&mut self, func: &TOp, var: &TData) -> Result<(), AutoDiffError> {
193        if self.data.contains(var) && self.op.contains(func) {
194            self.forward_op_dt.get_mut(func).expect("").remove(var);
195            self.backward_dt_op.get_mut(var).expect("").remove(func);
196            Ok(())
197        } else {
198            Err(AutoDiffError::new("invalid var or func"))
199        }
200    }
201
202    /// list data node without upstream op node in a set.
203    pub fn get_input_edge_data(&self) -> BTreeSet<TData> {
204        let mut jobs = BTreeSet::new();
205        for i in &self.data {
206            if self.backward_dt_op.get(i).expect("").is_empty() {
207                jobs.insert(*i);
208            }
209        }
210        jobs
211    }
212
213    /// list data node without downstream op node in a set.
214    pub fn get_output_edge_data(&self) -> BTreeSet<TData> {
215        let mut jobs = BTreeSet::new();
216        for i in &self.data {
217            if self.forward_dt_op.get(i).expect("").is_empty() {
218                jobs.insert(*i);
219            }
220        }
221        jobs
222    }
223
224    /// Connect input data, output data and operation
225    pub fn connect(&mut self, dti: &[TData],
226                   dto: &[TData],
227                   op: &TOp) -> Result<TOp, &str> {
228        let mut valid_ids = true;
229
230        // make sure pre-exist
231        if !self.op.contains(op) {
232            valid_ids = false;
233        }
234        // make sure input data pre-exist
235        for i in dti {
236            if !self.data.contains(i) {
237                valid_ids = false;
238            }
239        }
240        // make sure output data pre-exist
241        for i in dto {
242            if !self.data.contains(i) {
243                valid_ids = false;
244            }
245        }
246        
247        if valid_ids {
248            for i in dti {
249                self.forward_dt_op.get_mut(i).expect("").insert(*op);
250                self.backward_op_dt.get_mut(op).expect("").insert(*i);
251            }
252            for i in dto {
253                self.forward_op_dt.get_mut(op).expect("").insert(*i);
254                self.backward_dt_op.get_mut(i).expect("").insert(*op);
255            }
256            Ok(*op)
257        } else {
258            Err("Invalid id!")
259        }
260    }
261
262    /// Auxilary connect, This allows the graph to support loop.
263    pub fn connect_aux(&mut self, input_data: &[TData],
264                       output_data: &[TData],
265                       op: &TOp) -> Result<TOp, &str> {
266        if !self.op.contains(op) ||
267            input_data.iter().any(|x| !self.data.contains(x)) ||
268            output_data.iter().any(|x| !self.data.contains(x)) {
269                return Err("Invalid id!");
270            }
271        unimplemented!();
272        //return Ok(*op);
273    }
274
275    ///
276    /// Walk through the graph with a starting set of data nodes.
277    /// Go through backwards if forward is false.
278    /// The closure call provides the side-effect.
279    ///
280    /// This Walk() guarantee the input of visiting op is already visited
281    /// or it's an input.
282    ///
283    pub fn walk<F>(&self, start_set: &[TData],
284                   forward: bool,
285                   closure: F) -> Result<(), BTreeSet<TData>>
286    where F: Fn(&[TData], &[TData], &TOp)  {
287        let mut fdo = &self.forward_dt_op;
288        let mut fod = &self.forward_op_dt;
289        //let mut bdo = &self.backward_dt_op;
290        let mut bod = &self.backward_op_dt;
291        if !forward {
292            fdo = &self.backward_dt_op;
293            fod = &self.backward_op_dt;
294            //bdo = &self.forward_dt_op;
295            bod = &self.forward_op_dt;
296        }
297
298        // data id has a value
299        let mut jobs = BTreeSet::<TData>::new();
300        // op is done.
301        let mut done = BTreeSet::<TOp>::new(); // ops done.
302
303        for index in start_set {
304            jobs.insert(*index);
305        }
306        
307        loop {
308            let mut made_progress = false;
309
310            // collect ops needs to do given the data in jobs.
311            let mut edge_op = BTreeSet::<TOp>::new();
312            for dt in &jobs {
313                for op_candidate in &fdo[dt] {
314                    edge_op.insert(*op_candidate);
315                }
316            }
317
318            // process op if possible
319            for op_candidate in edge_op {
320                if bod[&op_candidate]
321                    .iter()
322                    .all(|dt| jobs.contains(dt)) {
323
324                        // collect input ids.
325                        let mut inputs = Vec::<TData>::new();
326                        for input in bod[&op_candidate].iter() {
327                            inputs.push(*input);
328                        }
329                        // collect output ids.
330                        let mut outputs = Vec::<TData>::new();
331                        for output in fod[&op_candidate].iter() {
332                            outputs.push(*output);
333                        }
334
335                        // all the closure
336                        closure(&inputs, &outputs, &op_candidate);
337
338                        // maintain the list
339                        // the following line should go before the rest.
340                        done.insert(op_candidate);
341                        // remove the data from jobs if all its downstream op is done.
342                        for input in bod[&op_candidate].iter() {
343                            if fdo[input]
344                                .iter()
345                                .all(|op| done.contains(op)) {
346                                    jobs.remove(input);
347                                }
348                        }
349                        // add the output back to the jobs.
350                        for output in fod[&op_candidate].iter() {
351                            // don't add to jobs if it's the final data node.
352                            if !fdo[output].is_empty() {
353                                jobs.insert(*output);                                
354                            }
355                        }
356
357                        // flag there is sth done.
358                        made_progress = true;
359                    }
360            }
361
362            if ! made_progress {
363                break;
364            }
365        }
366
367        if !jobs.is_empty() {
368            Err(jobs)
369        } else {
370            Ok(())
371        }
372    }
373
374    /////
375    ///// Walk through the graph with a starting set of data nodes.
376    ///// Go through backwards if forward is false.
377    /////
378    //pub fn walk_dyn(&self, start_set: &[TData],
379    //               forward: bool,
380    //                closure: dyn Calling) -> Result<(), BTreeSet<TData>> {
381    //    Ok(())
382    //}
383
384    /// 
385    pub fn append(&mut self, other: &Self,
386                  data_key_map: BTreeMap<TData, TData>,
387                  op_key_map: BTreeMap<TOp, TOp>) -> Result<(), AutoDiffError> {
388
389        for key in other.iter_data() {
390            self.data.insert(data_key_map[key]);
391        }
392        for key in other.iter_op() {
393            self.op.insert(op_key_map[key]);
394        }
395        for (key, value) in other.forward_dt_op.iter() {
396            let mut new_set = BTreeSet::new();
397            for key in value.iter() {
398                new_set.insert(op_key_map[key]);
399            }
400            self.forward_dt_op.insert(data_key_map[key], new_set);
401        }
402        for (key, value) in other.backward_dt_op.iter() {
403            let mut new_set = BTreeSet::new();
404            for key in value.iter() {
405                new_set.insert(op_key_map[key]);
406            }
407            self.backward_dt_op.insert(data_key_map[key], new_set);
408        }
409        for (key, value) in other.forward_op_dt.iter() {
410            let mut new_set = BTreeSet::new();
411            for key in value.iter() {
412                new_set.insert(data_key_map[key]);
413            }
414            self.forward_op_dt.insert(op_key_map[key], new_set);
415        }
416        for (key, value) in other.backward_op_dt.iter() {
417            let mut new_set = BTreeSet::new();
418            for key in value.iter() {
419                new_set.insert(data_key_map[key]);
420            }
421            self.backward_op_dt.insert(op_key_map[key], new_set);
422        }
423
424        
425        Ok(())
426    }
427}
428
/// Borrowing iterator over node ids stored in a graph's ordered sets.
///
/// Ids are yielded in ascending order, and the exact remaining length is
/// always known.
pub struct NodeIterator<'a, TNode> {
    iter: std::collections::btree_set::Iter<'a, TNode>,
}

impl<'a, TNode> Iterator for NodeIterator<'a, TNode> {
    type Item = &'a TNode;

    fn next(&mut self) -> Option<Self::Item> {
        self.iter.next()
    }

    // Forward the underlying exact bounds so `collect`, `count`, etc. can
    // size their work up front instead of seeing the default `(0, None)`.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.iter.size_hint()
    }
}

impl<'a, TNode> DoubleEndedIterator for NodeIterator<'a, TNode> {
    fn next_back(&mut self) -> Option<Self::Item> {
        self.iter.next_back()
    }
}

// `len()` relies on the forwarded `size_hint` being exact, which holds for
// `btree_set::Iter`.
impl<'a, TNode> ExactSizeIterator for NodeIterator<'a, TNode> {}

impl<'a, TNode> std::iter::FusedIterator for NodeIterator<'a, TNode> {}
439
440impl fmt::Debug for Graph<GenKey, GenKey> {
441    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
442        writeln!(f, "Dumping graph")?;
443        writeln!(f, "data: {:?}", self.data)?;
444        writeln!(f, "op: {:?}", self.op)?;
445        writeln!(f, "dt 2 op: {:?}", self.forward_dt_op)?;
446        writeln!(f, "op 2 dt: {:?}", self.forward_op_dt)
447    }
448}
449
450impl<T1: Ord, T2: Ord> PartialEq for Graph<T1, T2> {
451    fn eq(&self, other: &Self) -> bool {
452	self.data.eq(&other.data) &&
453	    self.op.eq(&other.op) &&
454	    self.forward_dt_op.eq(&other.forward_dt_op) &&
455	    self.forward_op_dt.eq(&other.forward_op_dt) &&
456	    self.backward_dt_op.eq(&other.backward_dt_op) &&
457	    self.backward_op_dt.eq(&other.backward_op_dt)
458    }
459}
460
461impl<T1: Ord, T2: Ord> Eq for Graph<T1, T2> {}
462
463
#[cfg(test)]
mod tests {
    use super::*;
    use crate::collection::generational_index::{GenKey};
    use std::cell::RefCell;

    #[test]
    fn new() {
        let _g = Graph::<GenKey, GenKey>::new();
    }

    // A   B
    //  \ /
    //   Op
    //   |
    //   C
    fn setup_y(g: &mut Graph<GenKey, GenKey>) {
        let data_a = GenKey::new(0,0);
        let data_b = GenKey::new(1,0);
        let data_c = GenKey::new(2,0);
        g.add_data(&data_a).expect("");
        g.add_data(&data_b).expect("");
        g.add_data(&data_c).expect("");

        let op_a = GenKey::new(0,0);
        g.add_op(&op_a).expect("");

        g.connect(&[data_a, data_b], &[data_c,], &op_a).expect("");
    }

    // A   B
    //  \ /
    //   Op1
    //   |
    //   C   D
    //    \ /
    //     Op2
    //     |
    //     E
    fn setup_yy(g: &mut Graph<GenKey, GenKey>) {
        let data_a = GenKey::new(0,0);
        let data_b = GenKey::new(1,0);
        let data_c = GenKey::new(2,0);
        let data_d = GenKey::new(3,0);
        let data_e = GenKey::new(4,0);
        g.add_data(&data_a).expect("");
        g.add_data(&data_b).expect("");
        g.add_data(&data_c).expect("");
        g.add_data(&data_d).expect("");
        g.add_data(&data_e).expect("");

        let op1 = GenKey::new(0,0);
        g.add_op(&op1).expect("");
        let op2 = GenKey::new(1,0);
        g.add_op(&op2).expect("");

        g.connect(&[data_a, data_b], &[data_c,], &op1).expect("");
        g.connect(&[data_c, data_d], &[data_e,], &op2).expect("");
    }

    #[test]
    fn iter() {
        let mut g = Graph::new();
        setup_yy(&mut g);

        assert_eq!(g.iter_data().count(), 5);
        assert_eq!(g.iter_op().count(), 2);
        assert!(g.iter_data().any(|k| *k == GenKey::new(4, 0)));
    }

    #[test]
    fn test_get_input_cache() {
        let mut g = Graph::new();
        setup_y(&mut g);
        assert_eq!(g.get_input_edge_data().len(), 2);

        let mut g = Graph::<GenKey, GenKey>::new();
        setup_yy(&mut g);
        assert_eq!(g.get_input_edge_data().len(), 3);
    }

    #[test]
    fn test_get_output_cache() {
        let mut g = Graph::new();
        setup_y(&mut g);
        assert_eq!(g.get_output_edge_data().len(), 1);

        let mut g = Graph::<GenKey, GenKey>::new();
        setup_yy(&mut g);
        assert_eq!(g.get_output_edge_data().len(), 1);
    }

    #[test]
    fn add_data() {
        let mut g = Graph::<GenKey, GenKey>::new();
        let data1 = GenKey::new(0,0);
        let data2 = GenKey::new(1,0);
        g.add_data(&data1).expect("");
        g.add_data(&data2).expect("");
        // Adding the same id twice is rejected.
        assert!(g.add_data(&data1).is_err());
    }

    #[test]
    fn connect_rejects_unknown_ids() {
        let mut g = Graph::<GenKey, GenKey>::new();
        setup_y(&mut g);
        let op_a = GenKey::new(0, 0);
        let missing = GenKey::new(9, 0);
        assert!(g.connect(&[missing], &[GenKey::new(2, 0)], &op_a).is_err());
        assert!(g.connect(&[GenKey::new(0, 0)], &[GenKey::new(2, 0)], &missing).is_err());
    }

    #[test]
    fn drop_data_detaches_edges() {
        let mut g = Graph::<GenKey, GenKey>::new();
        setup_yy(&mut g);
        let data_c = GenKey::new(2, 0);
        let op1 = GenKey::new(0, 0);
        let op2 = GenKey::new(1, 0);

        g.drop_data(&data_c).expect("");
        assert!(g.iter_op_given_input(&data_c).is_err());
        // Op1 lost its only output, op2 lost one of its two inputs.
        assert_eq!(g.iter_output_given_op(&op1).expect("").count(), 0);
        assert_eq!(g.iter_input_given_op(&op2).expect("").count(), 1);
        assert!(g.drop_data(&data_c).is_err());
    }

    #[test]
    fn drop_op_detaches_edges() {
        let mut g = Graph::<GenKey, GenKey>::new();
        setup_y(&mut g);
        let op_a = GenKey::new(0, 0);

        g.drop_op(&op_a).expect("");
        // Every data node is now isolated.
        assert_eq!(g.get_input_edge_data().len(), 3);
        assert_eq!(g.get_output_edge_data().len(), 3);
        assert!(g.drop_op(&op_a).is_err());
    }

    #[test]
    fn walk_forward_and_backward() {
        let mut g = Graph::<GenKey, GenKey>::new();
        setup_yy(&mut g);
        let op1 = GenKey::new(0, 0);
        let op2 = GenKey::new(1, 0);

        // Record (op, #upstream data, #downstream data) per visit.
        let forward = RefCell::new(Vec::new());
        let starts = [GenKey::new(0, 0), GenKey::new(1, 0), GenKey::new(3, 0)];
        assert!(g.walk(&starts, true, |i, o, op| {
            forward.borrow_mut().push((*op, i.len(), o.len()))
        }).is_ok());
        assert_eq!(*forward.borrow(), vec![(op1, 2, 1), (op2, 2, 1)]);

        let backward = RefCell::new(Vec::new());
        assert!(g.walk(&[GenKey::new(4, 0)], false, |i, o, op| {
            backward.borrow_mut().push((*op, i.len(), o.len()))
        }).is_ok());
        assert_eq!(*backward.borrow(), vec![(op2, 1, 2), (op1, 1, 2)]);
    }
}
569