rai_core/transforms/
raiexpr.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
use crate::{Func, Shape, Tensor, TensorIter, Value};
use colored::*;
use std::collections::{BTreeSet, HashMap};

/// Renders the graph that `func` builds on `input` as a colored, textual IR
/// listing of the form `fn(<inputs>) -> (<output types>) { <body> return (<outputs>) }`.
///
/// `func` is invoked exactly once with `input`. Every tensor reachable from
/// the outputs without passing through an input tensor becomes one body line
/// `%N:<type> = <op> <operands>`. Tensors get short sequential names
/// (`%1`, `%2`, ...) in the order they are first printed: input parameters
/// first, then body lines in tape order. The numbering therefore does not
/// expose the raw `Tensor::id()` values.
///
/// With the `debug-location` feature enabled, each body line also carries a
/// trailing `// <location>` comment taken from `Tensor::location()`.
pub fn raiexpr<K, IN, OUT, F>(func: &F, input: IN) -> String
where
    F: Func<K, IN, OUT>,
    IN: Value,
    OUT: Value,
{
    // Assigns compact display numbers to tensors, keyed by `Tensor::id()`.
    struct Namer {
        assigned: HashMap<usize, usize>,
        counter: usize,
    }

    impl Namer {
        // `%N` (yellow). A tensor seen for the first time takes the next number.
        fn name(&mut self, t: &Tensor) -> String {
            let Namer { assigned, counter } = self;
            let n = *assigned.entry(t.id()).or_insert_with(|| {
                *counter += 1;
                *counter
            });
            format!("%{}", n).yellow().to_string()
        }

        // `%N:<type>`: a name plus its type annotation.
        fn typed(&mut self, t: &Tensor) -> String {
            let name = self.name(t);
            format!("{}:{}", name, type_of(t))
        }

        // One tab-indented body line. The defined tensor is named before its
        // operands, so it claims its number first.
        fn line(&mut self, t: &Tensor) -> String {
            let lhs = self.typed(t);
            let operands = t
                .inputs()
                .iter()
                .map(|v| self.typed(v))
                .collect::<Vec<_>>()
                .join(" ");
            format!(
                "\t{} = {} {}{}",
                lhs,
                t.primitive().dot_label(),
                operands,
                location_suffix(t)
            )
        }
    }

    // `<dtype><shape>`: the lower-cased Debug dtype (cyan) followed by the
    // Debug shape (purple).
    fn type_of(t: &Tensor) -> String {
        let dtype = format!("{:?}", t.dtype()).to_lowercase();
        let shape = format!("{:?}", t.shape());
        format!("{}{}", dtype.cyan(), shape.purple())
    }

    // Trailing source-location comment, emitted only with `debug-location`.
    #[cfg(feature = "debug-location")]
    fn location_suffix(t: &Tensor) -> String {
        format!(" // {}", t.location())
    }

    #[cfg(not(feature = "debug-location"))]
    fn location_suffix(_t: &Tensor) -> String {
        String::new()
    }

    // `input` is moved into `func`, so its tensors must be captured first.
    let in_tensors = input.tensors();
    let traced = func.invoke(input);
    let out_tensors = traced.tensors();

    // Gather the tape: everything reachable from the outputs, stopping at the
    // inputs. An explicit worklist instead of recursion keeps deep graphs from
    // overflowing the call stack.
    // TODO: replace with a real topological sort. For now the BTreeSet's
    // ordering (by id, per the original note) serves as the schedule. This
    // assumes operands always sort before the tensors that consume them.
    let boundary: BTreeSet<Tensor> = in_tensors.tensor_iter().cloned().collect();
    let mut tape: BTreeSet<Tensor> = BTreeSet::new();
    let mut worklist: Vec<Tensor> = out_tensors.tensor_iter().cloned().collect();
    while let Some(t) = worklist.pop() {
        if boundary.contains(&t) || tape.contains(&t) {
            continue;
        }
        worklist.extend(t.inputs().iter().cloned());
        tape.insert(t);
    }

    let mut namer = Namer {
        assigned: HashMap::new(),
        counter: 0,
    };

    // Naming order is significant: parameters claim the first numbers, then
    // body lines in tape order. Every output is either a parameter or on the
    // tape, so the return list only looks up names that already exist.
    let params = in_tensors
        .tensor_iter()
        .map(|t| namer.typed(t))
        .collect::<Vec<String>>()
        .join(", ");
    let result_types = out_tensors
        .tensor_iter()
        .map(type_of)
        .collect::<Vec<String>>()
        .join(", ");
    let body = tape
        .iter()
        .map(|t| namer.line(t))
        .collect::<Vec<String>>()
        .join("\n");
    let returned = out_tensors
        .tensor_iter()
        .map(|t| namer.name(t))
        .collect::<Vec<String>>()
        .join(", ");

    format!(
        "fn({}) -> ({}) {{\n{}\n\treturn ({})\n}}",
        params, result_types, body, returned
    )
}