1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
//! Execution path

use crate::*;
use anyhow::Result;
use std::collections::BTreeSet;

/// A contraction path: the subscripts of the whole expression together with
/// the ordered list of steps (`reduced_subscripts`) used to evaluate it.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Path {
    /// Subscripts of the complete expression, before any factorization.
    original: Subscripts,
    /// Contraction steps, in the order they are meant to be executed.
    reduced_subscripts: Vec<Subscripts>,
}

/// Exposes the contraction steps of a `Path` as a slice, so callers can write
/// `path.len()`, `path[0]` or `path.iter()` directly.
impl std::ops::Deref for Path {
    type Target = [Subscripts];

    fn deref(&self) -> &Self::Target {
        self.reduced_subscripts.as_slice()
    }
}

impl Path {
    /// Output subscript of the original (unfactorized) expression.
    pub fn output(&self) -> &Subscript {
        let original = &self.original;
        &original.output
    }

    /// Number of input tensors taken by the original expression.
    pub fn num_args(&self) -> usize {
        let inputs = &self.original.inputs;
        inputs.len()
    }

    /// Largest compute order among all steps of this path.
    pub fn compute_order(&self) -> usize {
        compute_order(self.reduced_subscripts.as_slice())
    }

    /// Largest memory order among all steps of this path.
    pub fn memory_order(&self) -> usize {
        memory_order(self.reduced_subscripts.as_slice())
    }

    /// Parse `indices` (e.g. `"ij,jk->ik"`) and pick the cheapest contraction
    /// path by exhaustively trying every factorization.
    ///
    /// # Errors
    ///
    /// Returns any error from parsing `indices` or from factorizing the
    /// parsed subscripts.
    pub fn brute_force(indices: &str) -> Result<Self> {
        let mut names = Namespace::init();
        let original = Subscripts::from_raw_indices(&mut names, indices)?;
        let reduced_subscripts = brute_force_work(&mut names, original.clone())?;
        Ok(Self {
            original,
            reduced_subscripts,
        })
    }
}

/// Largest compute order over every step in `ss`.
///
/// # Panics
///
/// Panics if `ss` is empty. Paths produced by `brute_force_work` always hold
/// at least one step, so an empty slice means an internal invariant broke.
fn compute_order(ss: &[Subscripts]) -> usize {
    ss.iter()
        .map(|step| step.compute_order())
        .max()
        .expect("a contraction path always contains at least one step")
}

/// Largest memory order over every step in `ss`.
///
/// # Panics
///
/// Panics if `ss` is empty. Paths produced by `brute_force_work` always hold
/// at least one step, so an empty slice means an internal invariant broke.
fn memory_order(ss: &[Subscripts]) -> usize {
    ss.iter()
        .map(|step| step.memory_order())
        .max()
        .expect("a contraction path always contains at least one step")
}

/// Exhaustively search for the cheapest way to contract `subscripts`.
///
/// Every subset of at least two (but not all) inputs is factorized out of
/// `subscripts` into an `inner` step. The remaining `outer` expression is
/// then solved recursively, giving the candidate path `inner, <steps for
/// outer>`. Leaving `subscripts` unfactorized is also a candidate.
///
/// The candidate with the smallest `(compute_order, memory_order)` pair is
/// returned. On a tie the earliest candidate wins (`Iterator::min_by_key`),
/// so factorized candidates are preferred over the unfactorized one, which is
/// considered last.
///
/// `names` is only cloned here: each candidate allocates intermediate names in
/// its own copy, and the caller's namespace is left unchanged.
///
/// # Errors
///
/// Propagates the first error returned by `Subscripts::factorize`, whether
/// from this call or a recursive one.
fn brute_force_work(names: &mut Namespace, subscripts: Subscripts) -> Result<Vec<Subscripts>> {
    let n = subscripts.inputs.len();
    if n <= 2 {
        // Cannot be factorized anymore
        return Ok(vec![subscripts]);
    }

    // Bit `i` of `mask` decides whether input `i` joins the group contracted first.
    let mut candidates: Vec<Vec<Subscripts>> = Vec::new();
    for mask in 0..2_usize.pow(n as u32) {
        let pos: BTreeSet<_> = subscripts
            .inputs
            .iter()
            .enumerate()
            .filter(|&(i, _)| ((mask >> i) & 1) == 1)
            .map(|(_, input)| *input.position())
            .collect();
        // At least two tensors, and not be all
        if pos.len() < 2 || pos.len() >= n {
            continue;
        }

        let mut branch_names = names.clone();
        let (inner, outer) = subscripts.factorize(&mut branch_names, pos)?;
        let rest = brute_force_work(&mut branch_names, outer)?;
        candidates.push(std::iter::once(inner).chain(rest).collect());
    }
    // Keeping the expression as a single step is always an option as well.
    candidates.push(vec![subscripts]);

    let best = candidates
        .into_iter()
        .min_by_key(|path| (compute_order(path), memory_order(path)))
        .expect("subpath never be empty");
    Ok(best)
}

#[cfg(test)]
mod test {
    use super::*;

    /// Render every step of `path` through its `Display` impl, in order.
    fn rendered(path: &Path) -> Vec<String> {
        path.iter().map(|step| step.to_string()).collect()
    }

    #[test]
    fn brute_force_ij_jk() -> Result<()> {
        let path = Path::brute_force("ij,jk->ik")?;
        assert_eq!(rendered(&path), ["ij,jk->ik | arg0,arg1->out0"]);
        Ok(())
    }

    #[test]
    fn brute_force_ij_jk_kl_l() -> Result<()> {
        let path = Path::brute_force("ij,jk,kl,l->i")?;
        let expected = [
            "kl,l->k | arg2,arg3->out1",
            "k,jk->j | out1,arg1->out2",
            "j,ij->i | out2,arg0->out0",
        ];
        assert_eq!(rendered(&path), expected);
        Ok(())
    }

    #[test]
    fn brute_force_i_i_i() -> Result<()> {
        let path = Path::brute_force("i,i,i->")?;
        assert_eq!(rendered(&path), ["i,i,i-> | arg0,arg1,arg2->out0"]);
        Ok(())
    }
}