tract_data/dim/mod.rs

1//! Extended dimension support
2use crate::internal::*;
3use num_traits::Zero;
4use std::fmt;
5use std::ops;
6
7mod assertion;
8mod parse;
9mod resolve;
10mod sym;
11mod tree;
12
13pub use self::assertion::Assertion;
14pub use self::parse::parse_tdim;
15pub use self::resolve::solve_for;
16pub use self::sym::{Symbol, SymbolScope, SymbolValues};
17pub use self::tree::{TDim, TooEarly};
18
19use crate::{TractError, TractResult};
20
21/// A super-trait for value acting as tensor dimensions in tract.
22///
23/// Implemented by:
24///
25/// * `usize` for regular dimensions
26/// * `TDim` supporting regular and streaming dimensions
27pub trait DimLike:
28    Clone
29    + Default
30    + PartialEq
31    + From<usize>
32    + for<'a> std::convert::TryFrom<&'a TDim, Error = TractError>
33    + ::num_traits::Zero
34    + fmt::Debug
35    + fmt::Display
36    + std::hash::Hash
37    + ops::Add<Self, Output = Self>
38    + ops::Add<usize, Output = Self>
39    + for<'a> ops::Add<&'a Self, Output = Self>
40    + ops::Sub<Self, Output = Self>
41    + ops::Sub<usize, Output = Self>
42    + for<'a> ops::Sub<&'a Self, Output = Self>
43    + ops::Mul<Self, Output = Self>
44    + ops::Mul<usize, Output = Self>
45    + for<'a> ops::Mul<&'a Self, Output = Self>
46    + ops::Div<usize, Output = Self>
47    + ops::Rem<usize, Output = Self>
48    + Send
49    + Sync
50    + 'static
51    + std::iter::Sum
52    + std::iter::Product
53    + ToDim
54{
55    fn maybe_div(&self, other: &Self) -> TractResult<(Self, u64)>;
56
57    /// Integer divise, rounding up to next integer.
58    fn divceil(&self, other: usize) -> Self {
59        (self.clone() + other - 1) / other
60    }
61
62    /// Convert to regular integer.
63    fn to_i64(&self) -> TractResult<i64>;
64
65    fn to_usize(&self) -> TractResult<usize> {
66        self.to_i64().map(|d| d as usize)
67    }
68
69    fn to_isize(&self) -> TractResult<isize> {
70        self.to_i64().map(|d| d as isize)
71    }
72
73    fn to_i32(&self) -> TractResult<i32> {
74        self.to_i64().map(|d| d as i32)
75    }
76
77    /// do not use num_traits::Mul as it implies a regular Mul
78    fn one() -> Self;
79
80    /// Substitute as many symbols as possible in the dim value.
81    fn eval(&self, values: &SymbolValues) -> Self;
82
83    /// Full evaluation of the symbol, failing if a symbol is missing
84    fn eval_to_i64(&self, values: &SymbolValues) -> TractResult<i64>;
85
86    fn substitute(&self, from: &Symbol, to: &Self) -> TractResult<Self>;
87
88    fn broadcast(self, other: Self) -> TractResult<Self>;
89    fn mini(self, other: Self) -> Self;
90    fn maxi(self, other: Self) -> Self;
91
92    fn compatible_with(&self, other: &Self) -> bool;
93}
94
95impl DimLike for TDim {
96    fn maybe_div(&self, other: &Self) -> TractResult<(Self, u64)> {
97        if self.is_zero() {
98            return Ok((TDim::zero(), 1));
99        } else if other.is_zero() {
100            bail!("Division by zero")
101        }
102        fn expand(dim: &TDim) -> (i64, Vec<TDim>) {
103            match dim {
104                TDim::Mul(terms) => terms.iter().map(expand).fold((1i64, vec![]), |acc, t| {
105                    (acc.0 * t.0, acc.1.into_iter().chain(t.1).collect())
106                }),
107                TDim::MulInt(a, terms) => {
108                    let (b, v) = expand(terms);
109                    (a * b, v)
110                }
111                TDim::Val(x) => (*x, vec![]),
112                TDim::Add(terms) => {
113                    let gcd =
114                        terms.iter().map(expand).map(|(n, _)| n).reduce(|a, b| a.gcd(&b)).unwrap();
115                    (
116                        gcd,
117                        vec![TDim::Add(terms.iter().map(|t| t.clone() / gcd).collect()).simplify()],
118                    )
119                }
120                it => (1, vec![it.clone()]),
121            }
122        }
123        let (mut num_int, mut num) = expand(self);
124        let (mut denum_int, mut denum) = expand(other);
125        if num == denum {
126            num = vec![];
127            denum = vec![];
128        }
129        for it in denum {
130            if let Some(pos) = num.iter().position(|n| n == &it) {
131                num.remove(pos);
132            } else {
133                bail!("Can't divide {} by {}", self, other)
134            }
135        }
136        use num_integer::Integer;
137        if denum_int < 0 {
138            num_int *= -1;
139            denum_int *= -1;
140        }
141        let gcd = num_int.gcd(&denum_int);
142        num_int /= gcd;
143        denum_int /= gcd;
144        Ok(((TDim::Mul(num) * num_int).reduce(), denum_int as u64))
145    }
146
147    fn to_i64(&self) -> TractResult<i64> {
148        TDim::to_i64(self)
149    }
150
151    fn one() -> Self {
152        Self::from(1)
153    }
154
155    fn eval(&self, values: &SymbolValues) -> Self {
156        self.eval(values)
157    }
158
159    fn substitute(&self, from: &Symbol, to: &Self) -> TractResult<Self> {
160        self.substitute(from, to)
161    }
162
163    fn eval_to_i64(&self, values: &SymbolValues) -> TractResult<i64> {
164        TDim::eval_to_i64(self, values)
165    }
166
167    fn broadcast(self, other: Self) -> TractResult<Self> {
168        if self.is_one() {
169            Ok(other)
170        } else if other.is_one() {
171            Ok(self)
172        } else {
173            Ok(TDim::Broadcast(vec![self, other]).simplify())
174        }
175    }
176
177    fn compatible_with(&self, other: &Self) -> bool {
178        self.compatible_with(other)
179    }
180
181    fn mini(self, other: Self) -> Self {
182        TDim::Min(vec![self, other]).simplify()
183    }
184
185    fn maxi(self, other: Self) -> Self {
186        TDim::Max(vec![self, other]).simplify()
187    }
188}
189
190impl<'a> std::convert::TryFrom<&'a TDim> for TDim {
191    type Error = TractError;
192    fn try_from(d: &'a TDim) -> TractResult<TDim> {
193        Ok(d.clone())
194    }
195}
196
197impl DimLike for usize {
198    fn maybe_div(&self, other: &Self) -> TractResult<(Self, u64)> {
199        use num_integer::Integer;
200        let gcd = self.gcd(other);
201        Ok((self / gcd, (other / gcd) as u64))
202    }
203
204    fn to_i64(&self) -> TractResult<i64> {
205        Ok(*self as i64)
206    }
207
208    fn one() -> usize {
209        1
210    }
211
212    fn eval(&self, _values: &SymbolValues) -> Self {
213        *self
214    }
215
216    fn substitute(&self, _from: &Symbol, _to: &Self) -> TractResult<Self> {
217        Ok(*self)
218    }
219
220    fn eval_to_i64(&self, _: &SymbolValues) -> TractResult<i64> {
221        Ok(*self as i64)
222    }
223
224    fn broadcast(self, other: Self) -> TractResult<Self> {
225        if self == 1 || self == other {
226            Ok(other)
227        } else if other == 1 {
228            Ok(self)
229        } else {
230            bail!("Can not broadcast {self} against {other}")
231        }
232    }
233
234    fn compatible_with(&self, other: &Self) -> bool {
235        self == other
236    }
237
238    fn mini(self, other: Self) -> Self {
239        if self < other {
240            self
241        } else {
242            other
243        }
244    }
245
246    fn maxi(self, other: Self) -> Self {
247        if self > other {
248            self
249        } else {
250            other
251        }
252    }
253}
254
255impl<'a> std::convert::TryFrom<&'a TDim> for usize {
256    type Error = TractError;
257    fn try_from(d: &'a TDim) -> TractResult<usize> {
258        d.to_usize()
259    }
260}
261
262/// Convenience trait to convert values to TDim.
263pub trait ToDim {
264    /// Convert self to a TDim.
265    fn to_dim(&self) -> TDim;
266}
267
268impl<I: Into<TDim> + Clone> ToDim for I {
269    fn to_dim(&self) -> TDim {
270        self.clone().into()
271    }
272}
273
274impl ToDim for &TDim {
275    fn to_dim(&self) -> TDim {
276        (*self).clone()
277    }
278}
279
#[cfg(test)]
mod tests {
    use super::*;

    lazy_static::lazy_static! {
        // A single symbol scope shared by every test, holding the symbol `S`.
        static ref S: (SymbolScope, Symbol) = {
            let scope = SymbolScope::default();
            let sym = scope.new_with_prefix("S");
            (scope, sym)
        };
    }

    /// The shared `S` symbol, lifted into a `TDim`.
    pub fn s() -> TDim {
        let sym: Symbol = S.1.clone();
        sym.into()
    }

    #[test]
    fn div() {
        let twelve = TDim::from(12);
        let four = TDim::from(4);
        assert_eq!(twelve.maybe_div(&four).unwrap(), (3.into(), 1));
    }

    #[test]
    fn div_sym_int() {
        let numerator = s() * 12;
        let quotient = numerator.maybe_div(&TDim::from(4)).unwrap();
        assert_eq!(quotient, (s() * 3, 1));
    }

    #[test]
    fn div_sym_sym() {
        let numerator = s() * 12;
        let denominator = s() * 4;
        assert_eq!(numerator.maybe_div(&denominator).unwrap(), (3.into(), 1));
    }

    #[test]
    fn div_sym_sym_ratio() {
        let numerator = s() * 13;
        let denominator = s() * 4;
        assert_eq!(numerator.maybe_div(&denominator).unwrap(), (13.into(), 4));
    }

    #[test]
    fn div_sym_sym_rem() {
        let numerator = s() + 1;
        let denominator = s() * 4;
        let outcome = numerator.maybe_div(&denominator);
        assert!(outcome.is_err());
    }

    #[test]
    fn div_sym_sym_simply_1() {
        let sym = s();
        assert_eq!(sym.maybe_div(&sym).unwrap(), (TDim::Val(1), 1));
    }

    #[test]
    fn div_sym_sym_complex() {
        let s = s();
        let b = S.0.sym("b");
        let numerator = 256.to_dim() * &s * &b;
        let denominator = 1.to_dim() * &s * &b;
        assert_eq!(numerator.maybe_div(&denominator).unwrap(), (256.into(), 1));
    }

    #[test]
    fn div_sym_sym_with_add() {
        let numerator = s() * 80 - 160;
        let denominator = s() - 2;
        assert_eq!(numerator.maybe_div(&denominator).unwrap(), (80.into(), 1));
    }
}