1use std::collections::{BTreeMap, BTreeSet};
4use std::fmt;
5use crate::err::AutoDiffError;
6use super::generational_index::GenKey;
7
8#[cfg(feature = "use-serde")]
9use serde::{Serialize, Deserialize};
10
/// A bipartite graph linking data nodes (`TData`) and op nodes (`TOp`).
///
/// Every edge is recorded twice, once in a "forward" map and once in a
/// "backward" map, so neighbours can be looked up from either end without
/// scanning the whole graph.
#[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
#[derive(Clone)]
pub struct Graph<TData, TOp>
where
    TData: Ord,
    TOp: Ord,
{
    /// All registered data node ids.
    data: BTreeSet<TData>,
    /// All registered op node ids.
    op: BTreeSet<TOp>,
    /// data id -> ops that take this data as an input.
    forward_dt_op: BTreeMap<TData, BTreeSet<TOp>>,
    /// op id -> data this op outputs.
    forward_op_dt: BTreeMap<TOp, BTreeSet<TData>>,
    /// data id -> ops that output this data.
    backward_dt_op: BTreeMap<TData, BTreeSet<TOp>>,
    /// op id -> data this op takes as inputs.
    backward_op_dt: BTreeMap<TOp, BTreeSet<TData>>,
}
22
23impl<TData: Clone + Copy + Ord, TOp: Clone + Copy + Ord> Default for Graph<TData, TOp> {
24 fn default() -> Graph<TData, TOp> {
25 Graph{
26 data: BTreeSet::new(),
27 op: BTreeSet::new(),
28 forward_dt_op: BTreeMap::new(),
29 forward_op_dt: BTreeMap::new(),
30 backward_dt_op: BTreeMap::new(),
31 backward_op_dt: BTreeMap::new(),
32 }
33 }
34}
35
36impl<TData: Clone + Copy + Ord, TOp: Clone + Copy + Ord> Graph<TData, TOp> {
37 pub fn new() -> Graph<TData, TOp> {
39 Graph{
40 data: BTreeSet::new(),
41 op: BTreeSet::new(),
42 forward_dt_op: BTreeMap::new(),
43 forward_op_dt: BTreeMap::new(),
44 backward_dt_op: BTreeMap::new(),
45 backward_op_dt: BTreeMap::new(),
46 }
47 }
48
49 pub fn iter_data(&self) -> NodeIterator<TData> {
51 NodeIterator {
52 iter: self.data.iter()
53 }
54 }
55 pub fn iter_op(&self) -> NodeIterator<TOp> {
57 NodeIterator {
58 iter: self.op.iter()
59 }
60 }
61
62 pub fn iter_op_given_input(&self, var: &TData) -> Result<NodeIterator<TOp>, &str> {
66 if !self.data.contains(var) {
67 Err("Not a valid variable/data")
68 } else {
69 Ok(NodeIterator {
70 iter: self.forward_dt_op.get(var).expect("").iter()
71 })
72 }
73 }
74
75 pub fn iter_op_given_output(&self, var: &TData) -> Result<NodeIterator<TOp>, &str> {
79 if !self.data.contains(var) {
80 Err("Not a valid variable/data")
81 } else {
82 Ok(NodeIterator {
83 iter: self.backward_dt_op.get(var).expect("").iter()
84 })
85 }
86 }
87
88 pub fn iter_input_given_op(&self, func: &TOp) -> Result<NodeIterator<TData>, &str> {
92 if !self.op.contains(func) {
93 Err("Bad func id.")
94 } else {
95 Ok(NodeIterator {
96 iter: self.backward_op_dt.get(func).expect("").iter()
97 })
98 }
99 }
100
101 pub fn iter_output_given_op(&self, func: &TOp) -> Result<NodeIterator<TData>, &str> {
105 if !self.op.contains(func) {
106 Err("Bad func id.")
107 } else {
108 Ok(NodeIterator {
109 iter: self.forward_op_dt.get(func).expect("").iter()
110 })
111 }
112 }
113
114 pub fn add_data(&mut self, id: &TData) -> Result<TData, &str> {
116 if !self.data.contains(id) {
117 self.data.insert(*id);
118 self.forward_dt_op.insert(*id, BTreeSet::new());
119 self.backward_dt_op.insert(*id, BTreeSet::new());
120 Ok(*id)
121 } else {
122 Err("data is exits!")
123 }
124 }
125
126 pub fn drop_data(&mut self, id: &TData) -> Result<TData, &str> {
128 if self.data.contains(id) {
129 self.data.remove(id);
130 for i in self.forward_dt_op.get_mut(id).expect("").iter() {
131 self.backward_op_dt.get_mut(i).expect("").remove(id);
132 }
133 self.forward_dt_op.remove(id);
134 for i in self.backward_dt_op.get_mut(id).expect("").iter() {
135 self.forward_op_dt.get_mut(i).expect("").remove(id);
136 }
137 self.backward_dt_op.remove(id);
138
139 Ok(*id)
140 } else {
141 Err("data id is not found!")
142 }
143 }
144
145 pub fn add_op(&mut self, id: &TOp) -> Result<TOp, &str> {
147 if !self.op.contains(id) {
148 self.op.insert(*id);
149 self.forward_op_dt.insert(*id, BTreeSet::new());
150 self.backward_op_dt.insert(*id, BTreeSet::new());
151 Ok(*id)
152 } else {
153 Err("op id exists.")
154 }
155 }
156
157 pub fn drop_op(&mut self, id: &TOp) -> Result<TOp, &str> {
159 if self.op.contains(id) {
160 self.op.remove(id);
161 for i in self.forward_op_dt.get_mut(id).expect("").iter() {
162 self.backward_dt_op.get_mut(i).expect("").remove(id);
163 }
164 self.forward_op_dt.remove(id);
165 for i in self.backward_op_dt.get_mut(id).expect("").iter() {
166 self.forward_dt_op.get_mut(i).expect("").remove(id);
167 }
168 self.backward_op_dt.remove(id);
169 Ok(*id)
170 } else {
171 Err("op id is not found!")
172 }
173
174 }
175
176 pub fn decouple_data_func(&mut self, var: &TData, func: &TOp) -> Result<(), AutoDiffError> {
180 if self.data.contains(var) && self.op.contains(func) {
181 self.forward_dt_op.get_mut(var).expect("").remove(func);
182 self.backward_op_dt.get_mut(func).expect("").remove(var);
183 Ok(())
184 } else {
185 Err(AutoDiffError::new("invalid var or func"))
186 }
187 }
188
189 pub fn decouple_func_data(&mut self, func: &TOp, var: &TData) -> Result<(), AutoDiffError> {
193 if self.data.contains(var) && self.op.contains(func) {
194 self.forward_op_dt.get_mut(func).expect("").remove(var);
195 self.backward_dt_op.get_mut(var).expect("").remove(func);
196 Ok(())
197 } else {
198 Err(AutoDiffError::new("invalid var or func"))
199 }
200 }
201
202 pub fn get_input_edge_data(&self) -> BTreeSet<TData> {
204 let mut jobs = BTreeSet::new();
205 for i in &self.data {
206 if self.backward_dt_op.get(i).expect("").is_empty() {
207 jobs.insert(*i);
208 }
209 }
210 jobs
211 }
212
213 pub fn get_output_edge_data(&self) -> BTreeSet<TData> {
215 let mut jobs = BTreeSet::new();
216 for i in &self.data {
217 if self.forward_dt_op.get(i).expect("").is_empty() {
218 jobs.insert(*i);
219 }
220 }
221 jobs
222 }
223
224 pub fn connect(&mut self, dti: &[TData],
226 dto: &[TData],
227 op: &TOp) -> Result<TOp, &str> {
228 let mut valid_ids = true;
229
230 if !self.op.contains(op) {
232 valid_ids = false;
233 }
234 for i in dti {
236 if !self.data.contains(i) {
237 valid_ids = false;
238 }
239 }
240 for i in dto {
242 if !self.data.contains(i) {
243 valid_ids = false;
244 }
245 }
246
247 if valid_ids {
248 for i in dti {
249 self.forward_dt_op.get_mut(i).expect("").insert(*op);
250 self.backward_op_dt.get_mut(op).expect("").insert(*i);
251 }
252 for i in dto {
253 self.forward_op_dt.get_mut(op).expect("").insert(*i);
254 self.backward_dt_op.get_mut(i).expect("").insert(*op);
255 }
256 Ok(*op)
257 } else {
258 Err("Invalid id!")
259 }
260 }
261
262 pub fn connect_aux(&mut self, input_data: &[TData],
264 output_data: &[TData],
265 op: &TOp) -> Result<TOp, &str> {
266 if !self.op.contains(op) ||
267 input_data.iter().any(|x| !self.data.contains(x)) ||
268 output_data.iter().any(|x| !self.data.contains(x)) {
269 return Err("Invalid id!");
270 }
271 unimplemented!();
272 }
274
275 pub fn walk<F>(&self, start_set: &[TData],
284 forward: bool,
285 closure: F) -> Result<(), BTreeSet<TData>>
286 where F: Fn(&[TData], &[TData], &TOp) {
287 let mut fdo = &self.forward_dt_op;
288 let mut fod = &self.forward_op_dt;
289 let mut bod = &self.backward_op_dt;
291 if !forward {
292 fdo = &self.backward_dt_op;
293 fod = &self.backward_op_dt;
294 bod = &self.forward_op_dt;
296 }
297
298 let mut jobs = BTreeSet::<TData>::new();
300 let mut done = BTreeSet::<TOp>::new(); for index in start_set {
304 jobs.insert(*index);
305 }
306
307 loop {
308 let mut made_progress = false;
309
310 let mut edge_op = BTreeSet::<TOp>::new();
312 for dt in &jobs {
313 for op_candidate in &fdo[dt] {
314 edge_op.insert(*op_candidate);
315 }
316 }
317
318 for op_candidate in edge_op {
320 if bod[&op_candidate]
321 .iter()
322 .all(|dt| jobs.contains(dt)) {
323
324 let mut inputs = Vec::<TData>::new();
326 for input in bod[&op_candidate].iter() {
327 inputs.push(*input);
328 }
329 let mut outputs = Vec::<TData>::new();
331 for output in fod[&op_candidate].iter() {
332 outputs.push(*output);
333 }
334
335 closure(&inputs, &outputs, &op_candidate);
337
338 done.insert(op_candidate);
341 for input in bod[&op_candidate].iter() {
343 if fdo[input]
344 .iter()
345 .all(|op| done.contains(op)) {
346 jobs.remove(input);
347 }
348 }
349 for output in fod[&op_candidate].iter() {
351 if !fdo[output].is_empty() {
353 jobs.insert(*output);
354 }
355 }
356
357 made_progress = true;
359 }
360 }
361
362 if ! made_progress {
363 break;
364 }
365 }
366
367 if !jobs.is_empty() {
368 Err(jobs)
369 } else {
370 Ok(())
371 }
372 }
373
374 pub fn append(&mut self, other: &Self,
386 data_key_map: BTreeMap<TData, TData>,
387 op_key_map: BTreeMap<TOp, TOp>) -> Result<(), AutoDiffError> {
388
389 for key in other.iter_data() {
390 self.data.insert(data_key_map[key]);
391 }
392 for key in other.iter_op() {
393 self.op.insert(op_key_map[key]);
394 }
395 for (key, value) in other.forward_dt_op.iter() {
396 let mut new_set = BTreeSet::new();
397 for key in value.iter() {
398 new_set.insert(op_key_map[key]);
399 }
400 self.forward_dt_op.insert(data_key_map[key], new_set);
401 }
402 for (key, value) in other.backward_dt_op.iter() {
403 let mut new_set = BTreeSet::new();
404 for key in value.iter() {
405 new_set.insert(op_key_map[key]);
406 }
407 self.backward_dt_op.insert(data_key_map[key], new_set);
408 }
409 for (key, value) in other.forward_op_dt.iter() {
410 let mut new_set = BTreeSet::new();
411 for key in value.iter() {
412 new_set.insert(data_key_map[key]);
413 }
414 self.forward_op_dt.insert(op_key_map[key], new_set);
415 }
416 for (key, value) in other.backward_op_dt.iter() {
417 let mut new_set = BTreeSet::new();
418 for key in value.iter() {
419 new_set.insert(data_key_map[key]);
420 }
421 self.backward_op_dt.insert(op_key_map[key], new_set);
422 }
423
424
425 Ok(())
426 }
427}
428
/// Borrowing iterator over node ids stored in one of a graph's `BTreeSet`s.
///
/// Ids come out in ascending order. The iterator also reports an exact length
/// and can be walked from either end.
pub struct NodeIterator<'a, TNode> {
    iter: std::collections::btree_set::Iter<'a, TNode>,
}

impl<'a, TNode> Iterator for NodeIterator<'a, TNode> {
    type Item = &'a TNode;

    fn next(&mut self) -> Option<Self::Item> {
        self.iter.next()
    }

    // Forward the exact bounds of the underlying set iterator; the default
    // `(0, None)` would stop `collect` from preallocating.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.iter.size_hint()
    }
}

impl<'a, TNode> DoubleEndedIterator for NodeIterator<'a, TNode> {
    fn next_back(&mut self) -> Option<Self::Item> {
        self.iter.next_back()
    }
}

// Sound because `size_hint` forwards the exact bounds of `btree_set::Iter`.
impl<'a, TNode> ExactSizeIterator for NodeIterator<'a, TNode> {}
439
440impl fmt::Debug for Graph<GenKey, GenKey> {
441 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
442 writeln!(f, "Dumping graph")?;
443 writeln!(f, "data: {:?}", self.data)?;
444 writeln!(f, "op: {:?}", self.op)?;
445 writeln!(f, "dt 2 op: {:?}", self.forward_dt_op)?;
446 writeln!(f, "op 2 dt: {:?}", self.forward_op_dt)
447 }
448}
449
450impl<T1: Ord, T2: Ord> PartialEq for Graph<T1, T2> {
451 fn eq(&self, other: &Self) -> bool {
452 self.data.eq(&other.data) &&
453 self.op.eq(&other.op) &&
454 self.forward_dt_op.eq(&other.forward_dt_op) &&
455 self.forward_op_dt.eq(&other.forward_op_dt) &&
456 self.backward_dt_op.eq(&other.backward_dt_op) &&
457 self.backward_op_dt.eq(&other.backward_op_dt)
458 }
459}
460
461impl<T1: Ord, T2: Ord> Eq for Graph<T1, T2> {}
462
463
#[cfg(test)]
mod tests {
    use super::*;
    use crate::collection::generational_index::{GenKey};
    use std::cell::Cell;

    #[test]
    fn new() {
        let _g = Graph::<GenKey, GenKey>::new();
    }

    // a, b --op_a--> c
    fn setup_y(g: &mut Graph<GenKey, GenKey>) {
        let data_a = GenKey::new(0,0);
        let data_b = GenKey::new(1,0);
        let data_c = GenKey::new(2,0);
        g.add_data(&data_a).expect("");
        g.add_data(&data_b).expect("");
        g.add_data(&data_c).expect("");

        let op_a = GenKey::new(0,0);
        g.add_op(&op_a).expect("");

        g.connect(&[data_a, data_b], &[data_c,], &op_a).expect("");
    }

    // a, b --op1--> c
    // c, d --op2--> e
    fn setup_yy(g: &mut Graph<GenKey, GenKey>) {
        let data_a = GenKey::new(0,0);
        let data_b = GenKey::new(1,0);
        let data_c = GenKey::new(2,0);
        let data_d = GenKey::new(3,0);
        let data_e = GenKey::new(4,0);
        g.add_data(&data_a).expect("");
        g.add_data(&data_b).expect("");
        g.add_data(&data_c).expect("");
        g.add_data(&data_d).expect("");
        g.add_data(&data_e).expect("");

        let op1 = GenKey::new(0,0);
        g.add_op(&op1).expect("");
        let op2 = GenKey::new(1,0);
        g.add_op(&op2).expect("");

        g.connect(&[data_a, data_b], &[data_c,], &op1).expect("");
        g.connect(&[data_c, data_d], &[data_e,], &op2).expect("");
    }

    #[test]
    fn iter() {
        let mut g = Graph::new();
        setup_yy(&mut g);

        assert_eq!(g.iter_data().count(), 5);
        assert_eq!(g.iter_op().count(), 2);
    }

    #[test]
    fn test_get_input_cache() {
        let mut g = Graph::new();
        setup_y(&mut g);
        assert_eq!(g.get_input_edge_data().len(), 2);

        let mut g = Graph::<GenKey, GenKey>::new();
        setup_yy(&mut g);
        assert_eq!(g.get_input_edge_data().len(), 3);
    }

    #[test]
    fn test_get_output_cache() {
        let mut g = Graph::new();
        setup_y(&mut g);
        assert_eq!(g.get_output_edge_data().len(), 1);

        let mut g = Graph::<GenKey, GenKey>::new();
        setup_yy(&mut g);
        assert_eq!(g.get_output_edge_data().len(), 1);
    }

    #[test]
    fn add_data() {
        let mut g = Graph::<GenKey, GenKey>::new();
        let data1 = GenKey::new(0,0);
        let data2 = GenKey::new(1,0);
        g.add_data(&data1).expect("");
        g.add_data(&data2).expect("");

        // Re-registering an existing id is rejected and changes nothing.
        assert!(g.add_data(&data1).is_err());
        assert_eq!(g.iter_data().count(), 2);
    }

    #[test]
    fn connect_rejects_unknown_ids() {
        let mut g = Graph::<GenKey, GenKey>::new();
        setup_y(&mut g);
        let unknown = GenKey::new(9, 0);
        let op_a = GenKey::new(0, 0);

        assert!(g.connect(&[unknown], &[], &op_a).is_err());
        assert!(g.connect(&[], &[], &unknown).is_err());
    }

    #[test]
    fn drop_data_detaches_edges() {
        let mut g = Graph::<GenKey, GenKey>::new();
        setup_y(&mut g);
        let data_c = GenKey::new(2, 0);
        let op_a = GenKey::new(0, 0);

        g.drop_data(&data_c).expect("");
        assert_eq!(g.iter_data().count(), 2);
        assert_eq!(g.iter_output_given_op(&op_a).expect("").count(), 0);
        assert_eq!(g.iter_input_given_op(&op_a).expect("").count(), 2);
        assert!(g.drop_data(&data_c).is_err());
    }

    #[test]
    fn drop_op_detaches_edges() {
        let mut g = Graph::<GenKey, GenKey>::new();
        setup_y(&mut g);
        let op_a = GenKey::new(0, 0);

        g.drop_op(&op_a).expect("");
        assert_eq!(g.iter_op().count(), 0);
        // With the only op gone, every data node is both an input and an output.
        assert_eq!(g.get_input_edge_data().len(), 3);
        assert_eq!(g.get_output_edge_data().len(), 3);
        assert!(g.iter_input_given_op(&op_a).is_err());
    }

    #[test]
    fn walk_forward_and_backward() {
        let mut g = Graph::<GenKey, GenKey>::new();
        setup_yy(&mut g);
        let calls = Cell::new(0);

        let start = [GenKey::new(0,0), GenKey::new(1,0), GenKey::new(3,0)];
        assert!(g.walk(&start, true, |_, _, _| calls.set(calls.get() + 1)).is_ok());
        assert_eq!(calls.get(), 2);

        calls.set(0);
        let end = [GenKey::new(4,0)];
        assert!(g.walk(&end, false, |_, _, _| calls.set(calls.get() + 1)).is_ok());
        assert_eq!(calls.get(), 2);
    }
}
569