1use crate::*;
4use anyhow::Result;
5use std::collections::BTreeSet;
6
7#[derive(Debug, Clone, PartialEq, Eq)]
8pub struct Path {
9 original: Subscripts,
10 reduced_subscripts: Vec<Subscripts>,
11}
12
13impl std::ops::Deref for Path {
14 type Target = [Subscripts];
15 fn deref(&self) -> &[Subscripts] {
16 &self.reduced_subscripts
17 }
18}
19
20impl Path {
21 pub fn output(&self) -> &Subscript {
22 &self.original.output
23 }
24
25 pub fn num_args(&self) -> usize {
26 self.original.inputs.len()
27 }
28
29 pub fn compute_order(&self) -> usize {
30 compute_order(&self.reduced_subscripts)
31 }
32
33 pub fn memory_order(&self) -> usize {
34 memory_order(&self.reduced_subscripts)
35 }
36
37 pub fn brute_force(indices: &str) -> Result<Self> {
38 let mut names = Namespace::init();
39 let subscripts = Subscripts::from_raw_indices(&mut names, indices)?;
40 Ok(Path {
41 original: subscripts.clone(),
42 reduced_subscripts: brute_force_work(&mut names, subscripts)?,
43 })
44 }
45}
46
47fn compute_order(ss: &[Subscripts]) -> usize {
48 ss.iter()
49 .map(|ss| ss.compute_order())
50 .max()
51 .expect("self.0 never be empty")
52}
53
54fn memory_order(ss: &[Subscripts]) -> usize {
55 ss.iter()
56 .map(|ss| ss.memory_order())
57 .max()
58 .expect("self.0 never be empty")
59}
60
61fn brute_force_work(names: &mut Namespace, subscripts: Subscripts) -> Result<Vec<Subscripts>> {
62 if subscripts.inputs.len() <= 2 {
63 return Ok(vec![subscripts]);
65 }
66
67 let n = subscripts.inputs.len();
68 let mut subpaths = (0..2_usize.pow(n as u32))
69 .filter_map(|mut m| {
70 let mut pos = BTreeSet::new();
72 for i in 0..n {
73 if m % 2 == 1 {
74 pos.insert(*subscripts.inputs[i].position());
75 }
76 m /= 2;
77 }
78 if pos.len() >= 2 && pos.len() < n {
80 Some(pos)
81 } else {
82 None
83 }
84 })
85 .map(|pos| {
86 let mut names = names.clone();
87 let (inner, outer) = subscripts.factorize(&mut names, pos)?;
88 let mut sub = brute_force_work(&mut names, outer)?;
89 sub.insert(0, inner);
90 Ok(sub)
91 })
92 .collect::<Result<Vec<_>>>()?;
93 subpaths.push(vec![subscripts]);
94 Ok(subpaths
95 .into_iter()
96 .min_by_key(|path| (compute_order(path), memory_order(path)))
97 .expect("subpath never be empty"))
98}
99
#[cfg(test)]
mod test {
    use super::*;

    /// Runs `Path::brute_force` on `indices` and asserts that the chosen
    /// contraction steps render, in order, exactly as `expected`.
    fn assert_path(indices: &str, expected: &[&str]) -> Result<()> {
        let path = Path::brute_force(indices)?;
        let steps: Vec<String> = path.iter().map(ToString::to_string).collect();
        assert_eq!(steps, expected);
        Ok(())
    }

    #[test]
    fn brute_force_ij_jk() -> Result<()> {
        // Two operands cannot be split any further.
        assert_path("ij,jk->ik", &["ij,jk->ik | arg0,arg1->out0"])
    }

    #[test]
    fn brute_force_ij_jk_kl_l() -> Result<()> {
        // The chain is contracted starting from the `l` end.
        assert_path(
            "ij,jk,kl,l->i",
            &[
                "kl,l->k | arg2,arg3->out1",
                "k,jk->j | out1,arg1->out2",
                "j,ij->i | out2,arg0->out0",
            ],
        )
    }

    #[test]
    fn brute_force_i_i_i() -> Result<()> {
        // Expected to remain a single, unfactorized step.
        assert_path("i,i,i->", &["i,i,i-> | arg0,arg1,arg2->out0"])
    }
}