tensorflow/
expr.rs

1//! This module builds computation graphs.
2//!
3//! This module is unfinished.
4#![cfg(feature = "tensorflow_unstable")]
5
6use super::Graph;
7use super::Operation;
8use super::Shape;
9use super::Status;
10use super::Tensor;
11use super::TensorType;
12use std::cmp::Eq;
13use std::collections::HashMap;
14use std::convert::From;
15use std::fmt::Debug;
16use std::fmt::Display;
17use std::fmt::Error;
18use std::fmt::Formatter;
19use std::hash::Hash;
20use std::hash::Hasher;
21use std::marker::PhantomData;
22use std::ops;
23use std::rc::Rc;
24
/// Operator precedence levels, declared from loosest-binding (`Assign`) to
/// tightest-binding (`Atom`); the derived `Ord` follows that declaration order.
///
/// `Display` implementations compare their own level with their children's to
/// decide where parentheses are required.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum OpLevel {
    /// `=`: assignment.
    Assign,
    /// `+` and `-`.
    Add,
    /// `*`, `/`, `%`, and truncating division.
    Mul,
    /// Prefix operators such as negation.
    Unary,
    /// Leaves: variables, placeholders, and constants.
    Atom,
}
44
45////////////////////////
46
47/// A operation in an expression tree, which is a thin wrapper around an ExprImpl.
48///
49/// This is separate from ExprImpl because we want expressions to be wrapped in an Rc,
50/// and we can't directly implement std::ops::Add, etc., for Rc<E: ExprImpl<Tgt;>.
51#[derive(Debug, Clone)]
52pub struct Expr<T: TensorType> {
53    expr: Rc<dyn ExprImpl<T>>,
54}
55
56impl<T: TensorType> Expr<T> {
57    /// Wraps an ExprImpl.
58    pub fn new<I>(expr: I) -> Expr<T>
59    where
60        I: ExprImpl<T> + 'static,
61    {
62        Expr {
63            expr: Rc::new(expr),
64        }
65    }
66}
67
68impl<T: TensorType> ops::Deref for Expr<T> {
69    type Target = dyn ExprImpl<T>;
70
71    fn deref(&self) -> &Self::Target {
72        self.expr.deref()
73    }
74}
75
76impl<T: TensorType> Display for Expr<T> {
77    fn fmt(&self, f: &mut Formatter) -> Result<(), Error> {
78        Display::fmt(&self.expr, f)
79    }
80}
81
82impl<T: TensorType> From<T> for Expr<T> {
83    fn from(value: T) -> Self {
84        Expr::new(value)
85    }
86}
87
88////////////////////////
89
/// What is statically known about an expression's shape.
#[derive(Debug)]
pub enum ShapeHint<'a> {
    /// Nothing is known about the shape.
    Unknown,

    /// The shape is fully known; the slice lists the size of every dimension.
    Exactly(&'a [u64]),
}
99
100////////////////////////
101
102/// Trait implemented by all expression types.
103/// Most users will want to store an Expr instead.
104pub trait ExprImpl<T: TensorType>: Display + Debug {
105    /// Returns the precedence level for this operator.
106    fn op_level(&self) -> OpLevel;
107
108    /// Returns the child expressions.
109    ///
110    /// For example, the child expressions of `x + y` would be `x` and `y`.
111    fn children(&self) -> Vec<Box<dyn AnyExpr>>; // TODO: return an iterator
112
113    /// Creates an operation for the expression.
114    ///
115    /// The implementation must use the operations in the `children` parameter
116    /// rather than creating child operations itself.
117    fn create_operation(
118        &self,
119        graph: &mut Graph,
120        children: &[Operation],
121        id_gen: &mut dyn FnMut() -> String,
122    ) -> Result<Operation, Status>;
123
124    /// Returns the derivative of the expression with respect to the given variable.
125    fn derivative_by_variable(&self, var: &str) -> Result<Expr<T>, Status>;
126
127    /// Returns a hint about the expression's shape.
128    fn shape_hint(&self) -> ShapeHint {
129        ShapeHint::Unknown
130    }
131}
132
133impl<T: TensorType> ExprImpl<T> for T {
134    fn op_level(&self) -> OpLevel {
135        OpLevel::Atom
136    }
137
138    fn children(&self) -> Vec<Box<dyn AnyExpr>> {
139        vec![]
140    }
141
142    fn create_operation(
143        &self,
144        graph: &mut Graph,
145        _children: &[Operation],
146        id_gen: &mut dyn FnMut() -> String,
147    ) -> Result<Operation, Status> {
148        let mut nd = graph.new_operation("Const", &id_gen())?;
149        nd.set_attr_type("dtype", T::data_type())?;
150        let mut value = Tensor::new(&[1]);
151        value[0] = self.clone();
152        nd.set_attr_tensor("value", value)?;
153        nd.finish()
154    }
155
156    fn derivative_by_variable(&self, _var: &str) -> Result<Expr<T>, Status> {
157        Ok(Expr::from(T::zero()))
158    }
159}
160
161////////////////////////
162
// Generates one binary-operator expression node.
//
// Arguments, in order:
// * `$name`     - node struct name; also the `std::ops` trait it implements (`Add`, `Mul`, ...)
// * `$fn_name`  - that trait's method name (`add`, `mul`, ...)
// * `$op`       - operator symbol printed by `Display`
// * `$op_level` - `OpLevel` variant giving the operator's precedence
// * `$assoc`    - whether a right operand of equal precedence may be printed without
//                 parentheses (the invocations below pass `true` for `+` and `*`)
// * `$tf_op`    - TensorFlow op type passed to `Graph::new_operation`
// * `$doc`      - doc string attached to the generated struct
// * `$ximpl`    - extra `ExprImpl` items spliced into the impl; every invocation in this
//                 file supplies `derivative_by_variable` here
//
// For each operator this emits the node struct, `ops::$name` impls for both
// `Expr<T> op Expr<T>` and `Expr<T> op T`, a precedence-aware `Display`, and the
// `ExprImpl` impl.
macro_rules! impl_bin_op {
  ($name:ident, $fn_name:ident, $op:expr, $op_level:ident, $assoc:expr,
      $tf_op:expr, $doc:expr, $($ximpl:tt)*) => {
    #[doc = $doc]
    #[derive(Debug)]
    pub struct $name<T: TensorType> {
      left: Expr<T>,
      right: Expr<T>,
    }

    // `expr op expr`
    impl<T: TensorType> ops::$name for Expr<T> {
      type Output = Expr<T>;

      fn $fn_name(self, rhs: Expr<T>) -> Expr<T> {
        Expr::new($name {
            left: self,
            right: rhs,
        })
      }
    }

    // `expr op scalar`: the scalar is wrapped as a constant leaf via `Expr::from`.
    impl<T: TensorType> ops::$name<T> for Expr<T> {
      type Output = Expr<T>;

      fn $fn_name(self, rhs: T) -> Expr<T> {
        Expr::new($name {
            left: self,
            right: Expr::from(rhs),
        })
      }
    }

    impl<T: TensorType> Display for $name<T> {
      fn fmt(&self, f: &mut Formatter) -> Result<(), Error> {
        // Left operand: only a looser-binding child needs parentheses, because the
        // printed form is read left-to-right (`a - b - c` means `(a - b) - c`).
        if self.left.op_level() < OpLevel::$op_level {
          write!(f, "({})", self.left)?;
        } else {
          write!(f, "{}", self.left)?;
        }
        write!(f, concat!(" ", $op, " "))?;
        // Right operand: associative operators parenthesize only a looser child;
        // non-associative ones also parenthesize an equal-precedence child.
        // NOTE(review): equal precedence does not imply the same operator, so for `*`
        // (assoc = true) `a * (b / c)` and `a * (b % c)` print as `a * b / c` and
        // `a * b % c`, which read as `(a * b) / c` and `(a * b) % c`. Confirm whether
        // that ambiguity is acceptable.
        let paren = if $assoc {
          self.right.op_level() < OpLevel::$op_level
        } else {
          self.right.op_level() <= OpLevel::$op_level
        };
        if paren {
          write!(f, "({})", self.right)
        } else {
          write!(f, "{}", self.right)
        }
      }
    }

    impl<T: TensorType> ExprImpl<T> for $name<T> {
      fn op_level(&self) -> OpLevel {
        OpLevel::$op_level
      }

      fn children(&self) -> Vec<Box<dyn AnyExpr>> {
        vec![Box::new(self.left.clone()), Box::new(self.right.clone())]
      }

      // `children` are the compiled operands in `children()` order: left, then right.
      fn create_operation(&self, graph: &mut Graph, children: &[Operation],
          id_gen: &mut dyn FnMut() -> String) -> Result<Operation, Status> {
        let mut nd = graph.new_operation($tf_op, &id_gen())?;
        nd.add_input(children[0].clone());
        nd.add_input(children[1].clone());
        nd.finish()
      }

      $($ximpl)*
    }
  }
}
237
238impl_bin_op!(
239    Add,
240    add,
241    "+",
242    Add,
243    true,
244    "Add",
245    "Expression resulting from adding two subexpressions.",
246    fn derivative_by_variable(&self, var: &str) -> Result<Expr<T>, Status> {
247        Ok(self.left.derivative_by_variable(var)? + self.right.derivative_by_variable(var)?)
248    }
249);
250impl_bin_op!(
251    Sub,
252    sub,
253    "-",
254    Add,
255    false,
256    "Sub",
257    "Expression resulting from subtracting two subexpressions.",
258    fn derivative_by_variable(&self, var: &str) -> Result<Expr<T>, Status> {
259        Ok(self.left.derivative_by_variable(var)? - self.right.derivative_by_variable(var)?)
260    }
261);
262impl_bin_op!(
263    Mul,
264    mul,
265    "*",
266    Mul,
267    true,
268    "Mul",
269    "Expression resulting from multiplying two subexpressions.",
270    fn derivative_by_variable(&self, var: &str) -> Result<Expr<T>, Status> {
271        Ok(self.left.derivative_by_variable(var)? * self.right.clone()
272            + self.left.clone() * self.right.derivative_by_variable(var)?)
273    }
274);
275impl_bin_op!(
276    Div,
277    div,
278    "/",
279    Mul,
280    false,
281    "Div",
282    "Expression resulting from dividing two subexpressions.",
283    fn derivative_by_variable(&self, var: &str) -> Result<Expr<T>, Status> {
284        let num = self.left.derivative_by_variable(var)? * self.right.clone()
285            - self.left.clone() * self.right.derivative_by_variable(var)?;
286        let denom = self.right.clone() * self.right.clone();
287        Ok(num / denom)
288    }
289);
290impl_bin_op!(
291    Rem,
292    rem,
293    "%",
294    Mul,
295    false,
296    "Mod",
297    "Expression resulting from taking a modulus.",
298    fn derivative_by_variable(&self, var: &str) -> Result<Expr<T>, Status> {
299        Ok(self.left.derivative_by_variable(var)?
300            - TruncateDiv::new_expr(self.left.clone(), self.right.clone())
301                * self.right.derivative_by_variable(var)?)
302    }
303);
304
305////////////////////////
306
307/// Expression that assigns a value to a variable.
308#[derive(Debug)]
309pub struct TruncateDiv<T: TensorType> {
310    left: Expr<T>,
311    right: Expr<T>,
312}
313
314impl<T: TensorType> TruncateDiv<T> {
315    fn new(left: Expr<T>, right: Expr<T>) -> Self {
316        TruncateDiv { left, right }
317    }
318
319    /// Creates an expression that divides `left` by `right` and rounds toward zero.
320    pub fn new_expr(left: Expr<T>, right: Expr<T>) -> Expr<T> {
321        Expr::new(TruncateDiv::new(left, right))
322    }
323}
324
325impl<T: TensorType> Display for TruncateDiv<T> {
326    fn fmt(&self, f: &mut Formatter) -> Result<(), Error> {
327        write!(f, "{} // {}", self.left, self.right)
328    }
329}
330
331impl<T: TensorType> ExprImpl<T> for TruncateDiv<T> {
332    fn op_level(&self) -> OpLevel {
333        OpLevel::Mul
334    }
335
336    fn children(&self) -> Vec<Box<dyn AnyExpr>> {
337        vec![Box::new(self.left.clone()), Box::new(self.right.clone())]
338    }
339
340    fn create_operation(
341        &self,
342        graph: &mut Graph,
343        children: &[Operation],
344        id_gen: &mut dyn FnMut() -> String,
345    ) -> Result<Operation, Status> {
346        let mut nd = graph.new_operation("TruncateDiv", &id_gen())?;
347        nd.add_input(children[0].clone());
348        nd.add_input(children[1].clone());
349        nd.finish()
350    }
351
352    fn derivative_by_variable(&self, var: &str) -> Result<Expr<T>, Status> {
353        // Mod(x, y) = x - TruncateDiv(x, y) * y
354        // TruncateDiv(x, y) = (x - Mod(x, y)) / y
355        // d/dt TruncateDiv(x, y) = (y * d/dt (x - Mod(x, y)) - (x - Mod(x, y)) dy/dt) / (y * y)
356        let diff = self.left.clone() - self.left.clone() % self.right.clone();
357        let term1 = self.right.clone() * diff.derivative_by_variable(var)?;
358        let term2 = diff * self.right.derivative_by_variable(var)?;
359        Ok((term1 - term2) / (self.right.clone() * self.right.clone()))
360    }
361}
362
363////////////////////////
364
365/// Expression resulting from negation of an expression.
366#[derive(Debug)]
367pub struct Neg<T: TensorType> {
368    expr: Expr<T>,
369}
370
371impl<T: TensorType> ops::Neg for Expr<T> {
372    type Output = Expr<T>;
373
374    fn neg(self) -> Expr<T> {
375        Expr::new(Neg { expr: self })
376    }
377}
378
379impl<T: TensorType> Display for Neg<T> {
380    fn fmt(&self, f: &mut Formatter) -> Result<(), Error> {
381        write!(f, "-")?;
382        if self.expr.op_level() <= OpLevel::Unary {
383            write!(f, "({})", self.expr)
384        } else {
385            write!(f, "{}", self.expr)
386        }
387    }
388}
389
390impl<T: TensorType> ExprImpl<T> for Neg<T> {
391    fn op_level(&self) -> OpLevel {
392        OpLevel::Unary
393    }
394
395    fn children(&self) -> Vec<Box<dyn AnyExpr>> {
396        vec![Box::new(self.expr.clone())]
397    }
398
399    fn create_operation(
400        &self,
401        graph: &mut Graph,
402        children: &[Operation],
403        id_gen: &mut dyn FnMut() -> String,
404    ) -> Result<Operation, Status> {
405        let mut nd = graph.new_operation("Neg", &id_gen())?;
406        nd.add_input(children[0].clone());
407        nd.finish()
408    }
409
410    fn derivative_by_variable(&self, var: &str) -> Result<Expr<T>, Status> {
411        Ok(-self.expr.derivative_by_variable(var)?)
412    }
413}
414
415////////////////////////
416
417/// Expression for a variable.
418#[derive(Debug)]
419pub struct Variable<T: TensorType> {
420    shape: Vec<u64>,
421    name: String,
422    phantom: PhantomData<T>,
423}
424
425impl<T: TensorType> Variable<T> {
426    fn new(shape: &[u64], name: &str) -> Self {
427        Variable {
428            shape: Vec::from(shape),
429            name: name.to_string(),
430            phantom: PhantomData,
431        }
432    }
433
434    /// Creates an `Expr` for a variable.
435    pub fn new_expr(shape: &[u64], name: &str) -> Expr<T> {
436        Expr::new(Variable::new(shape, name))
437    }
438}
439
440impl<T: TensorType> Display for Variable<T> {
441    fn fmt(&self, f: &mut Formatter) -> Result<(), Error> {
442        write!(f, "{}", self.name)
443    }
444}
445
446impl<T: TensorType> ExprImpl<T> for Variable<T> {
447    fn op_level(&self) -> OpLevel {
448        OpLevel::Atom
449    }
450
451    fn children(&self) -> Vec<Box<dyn AnyExpr>> {
452        vec![]
453    }
454
455    fn create_operation(
456        &self,
457        graph: &mut Graph,
458        _children: &[Operation],
459        _id_gen: &mut dyn FnMut() -> String,
460    ) -> Result<Operation, Status> {
461        let mut nd = graph.new_operation("Variable", &self.name)?;
462        let shape = self
463            .shape
464            .iter()
465            .map(|dim_size| Some(*dim_size as i64))
466            .collect();
467
468        nd.set_attr_type("dtype", T::data_type()).unwrap();
469        nd.set_attr_shape("shape", &Shape(Some(shape))).unwrap();
470        nd.finish()
471    }
472
473    fn derivative_by_variable(&self, var: &str) -> Result<Expr<T>, Status> {
474        Ok(if var == self.name {
475            Expr::from(T::one())
476        } else {
477            Expr::from(T::zero())
478        })
479    }
480
481    fn shape_hint(&self) -> ShapeHint {
482        ShapeHint::Exactly(&self.shape)
483    }
484}
485
486////////////////////////
487
488/// Expression for a placeholder.
489#[derive(Debug)]
490pub struct Placeholder<T: TensorType> {
491    shape: Vec<u64>,
492    name: String,
493    phantom: PhantomData<T>,
494}
495
496impl<T: TensorType> Placeholder<T> {
497    fn new(shape: &[u64], name: &str) -> Self {
498        Placeholder {
499            shape: Vec::from(shape),
500            name: name.to_string(),
501            phantom: PhantomData,
502        }
503    }
504
505    /// Creates an `Expr` for a placeholder.
506    pub fn new_expr(shape: &[u64], name: &str) -> Expr<T> {
507        Expr::new(Placeholder::new(shape, name))
508    }
509}
510
511impl<T: TensorType> Display for Placeholder<T> {
512    fn fmt(&self, f: &mut Formatter) -> Result<(), Error> {
513        write!(f, "{}", self.name)
514    }
515}
516
517impl<T: TensorType> ExprImpl<T> for Placeholder<T> {
518    fn op_level(&self) -> OpLevel {
519        OpLevel::Atom
520    }
521
522    fn children(&self) -> Vec<Box<dyn AnyExpr>> {
523        vec![]
524    }
525
526    fn create_operation(
527        &self,
528        graph: &mut Graph,
529        _children: &[Operation],
530        _id_gen: &mut dyn FnMut() -> String,
531    ) -> Result<Operation, Status> {
532        let mut nd = graph.new_operation("Placeholder", &self.name)?;
533        let shape = self
534            .shape
535            .iter()
536            .map(|dim_size| Some(*dim_size as i64))
537            .collect();
538
539        nd.set_attr_type("dtype", T::data_type()).unwrap();
540        nd.set_attr_shape("shape", &Shape(Some(shape))).unwrap();
541        nd.finish()
542    }
543
544    fn derivative_by_variable(&self, _var: &str) -> Result<Expr<T>, Status> {
545        Ok(Expr::from(T::zero()))
546    }
547
548    fn shape_hint(&self) -> ShapeHint {
549        ShapeHint::Exactly(&self.shape)
550    }
551}
552
553////////////////////////
554
555/// Expression for a constant.
556#[derive(Debug)]
557pub struct Constant<T: TensorType> {
558    tensor: Tensor<T>,
559}
560
561impl<T: TensorType> Constant<T> {
562    /// Creates a constant with the given value.
563    pub fn new(tensor: Tensor<T>) -> Self {
564        Constant { tensor }
565    }
566
567    /// Creates a constant with the given value.
568    pub fn new_expr(tensor: Tensor<T>) -> Expr<T> {
569        Expr::new(Constant { tensor })
570    }
571}
572
573impl<T: TensorType> Display for Constant<T> {
574    fn fmt(&self, f: &mut Formatter) -> Result<(), Error> {
575        write!(f, "{}", self.tensor)
576    }
577}
578
579impl<T: TensorType> ExprImpl<T> for Constant<T> {
580    fn op_level(&self) -> OpLevel {
581        OpLevel::Atom
582    }
583
584    fn children(&self) -> Vec<Box<dyn AnyExpr>> {
585        vec![]
586    }
587
588    fn create_operation(
589        &self,
590        graph: &mut Graph,
591        _children: &[Operation],
592        id_gen: &mut dyn FnMut() -> String,
593    ) -> Result<Operation, Status> {
594        let mut nd = graph.new_operation("Const", &id_gen())?;
595
596        nd.set_attr_type("dtype", T::data_type())?;
597        nd.set_attr_tensor("value", self.tensor.clone())?;
598        nd.finish()
599    }
600
601    fn derivative_by_variable(&self, _var: &str) -> Result<Expr<T>, Status> {
602        Ok(Expr::from(T::zero()))
603    }
604}
605
606////////////////////////
607
608/// Expression that assigns a value to a variable.
609#[derive(Debug)]
610pub struct Assign<T: TensorType> {
611    variable: Expr<T>,
612    value: Expr<T>,
613}
614
615impl<T: TensorType> Assign<T> {
616    fn new(variable: Expr<T>, value: Expr<T>) -> Self {
617        Assign { variable, value }
618    }
619
620    /// Creates an expression that assigns `value` to `variable`.
621    pub fn new_expr(variable: Expr<T>, value: Expr<T>) -> Expr<T> {
622        Expr::new(Assign::new(variable, value))
623    }
624
625    /// Creates an expression that takes values from `iterable` to fill `variable`.
626    pub fn to(variable: Expr<T>, iterable: impl Iterator<Item = T>) -> crate::Result<Expr<T>> {
627        let constant = if let ShapeHint::Exactly(shape) = variable.expr.shape_hint() {
628            let values: Vec<_> = iterable
629                .take(shape.iter().product::<u64>() as usize)
630                .collect();
631
632            Constant::new_expr(Tensor::new(shape).with_values(&values)?)
633        } else {
634            return Err(invalid_arg!(
635                "Cannot assign to expression {} with unknown size!",
636                variable
637            ));
638        };
639
640        Ok(Assign::new_expr(variable, constant))
641    }
642}
643
644impl<T: TensorType> Display for Assign<T> {
645    fn fmt(&self, f: &mut Formatter) -> Result<(), Error> {
646        write!(f, "{} = {}", self.variable, self.value)
647    }
648}
649
650impl<T: TensorType> ExprImpl<T> for Assign<T> {
651    fn op_level(&self) -> OpLevel {
652        OpLevel::Assign
653    }
654
655    fn children(&self) -> Vec<Box<dyn AnyExpr>> {
656        vec![
657            Box::new(self.variable.clone()),
658            Box::new(self.value.clone()),
659        ]
660    }
661
662    fn create_operation(
663        &self,
664        graph: &mut Graph,
665        children: &[Operation],
666        id_gen: &mut dyn FnMut() -> String,
667    ) -> Result<Operation, Status> {
668        let mut nd = graph.new_operation("Assign", &id_gen())?;
669        nd.add_input(children[0].clone());
670        nd.add_input(children[1].clone());
671        nd.finish()
672    }
673
674    fn derivative_by_variable(&self, _var: &str) -> Result<Expr<T>, Status> {
675        Err(invalid_arg!("Cannot take the derivative of an assignment"))
676    }
677}
678
679////////////////////////
680
681// TODO: See if we can make this private.
682/// An `AnyExpr` is just an `Expr<T>` for some unknown `T`.
683/// Clients *should not* implement this.
684pub trait AnyExpr: Debug {
685    /// Returns a pointer usable as a map key which identifies this expression.
686    fn key(&self) -> *const ();
687
688    /// Returns the child expressions.
689    ///
690    /// For example, the child expressions of `x + y` would be `x` and `y`.
691    fn children(&self) -> Vec<Box<dyn AnyExpr>>; // TODO: return an iterator
692
693    /// Creates an operation for the expression.
694    ///
695    /// The implementation must use the operations in the `children` parameter
696    /// rather than creating child operations itself.
697    fn create_operation(
698        &self,
699        graph: &mut Graph,
700        children: &[Operation],
701        id_gen: &mut dyn FnMut() -> String,
702    ) -> Result<Operation, Status>;
703
704    /// Returns a boxed clone.
705    ///
706    /// This is used rather than the `Clone` trait because that would prevent
707    /// `AnyExpr` values from being used as trait objects.
708    fn clone_box(&self) -> Box<dyn AnyExpr>;
709}
710
711impl<T: TensorType> AnyExpr for Expr<T> {
712    #[allow(trivial_casts)]
713    fn key(&self) -> *const () {
714        self.expr.as_ref() as *const dyn ExprImpl<T> as *const ()
715    }
716
717    fn children(&self) -> Vec<Box<dyn AnyExpr>> {
718        self.expr.children()
719    }
720
721    fn create_operation(
722        &self,
723        graph: &mut Graph,
724        children: &[Operation],
725        id_gen: &mut dyn FnMut() -> String,
726    ) -> Result<Operation, Status> {
727        self.expr.create_operation(graph, children, id_gen)
728    }
729
730    fn clone_box(&self) -> Box<dyn AnyExpr> {
731        Box::new(self.clone())
732    }
733}
734
735#[derive(Debug)]
736struct Key(Box<dyn AnyExpr>);
737
738impl PartialEq for Key {
739    fn eq(&self, other: &Key) -> bool {
740        self.0.key() == other.0.key()
741    }
742}
743
744impl Eq for Key {}
745
746impl Hash for Key {
747    fn hash<H>(&self, state: &mut H)
748    where
749        H: Hasher,
750    {
751        state.write_isize(self.0.key() as isize)
752    }
753}
754
755/// A `Compiler` compiles `Expr`s to `Operation`s.
756#[derive(Debug)]
757pub struct Compiler<'l> {
758    graph: &'l mut Graph,
759    operations: HashMap<Key, Operation>,
760    next_id: i32,
761}
762
763impl<'l> Compiler<'l> {
764    /// Creates a compiler for the given graph.
765    pub fn new(graph: &'l mut Graph) -> Self {
766        Compiler {
767            graph,
768            operations: HashMap::new(),
769            next_id: 0,
770        }
771    }
772
773    /// Compiles the expression.
774    pub fn compile<T: TensorType>(&mut self, expr: Expr<T>) -> Result<Operation, Status> {
775        self.compile_any(Box::new(expr))
776    }
777
778    /// Compiles the expression.
779    pub fn compile_any(&mut self, expr: Box<dyn AnyExpr>) -> Result<Operation, Status> {
780        let mut child_operations = vec![];
781        for child in expr.children() {
782            let key = Key(child.clone_box());
783            // The result is mapped separately from the match statement below to avoid
784            // reference lifetime isues.
785            let value = self.operations.get(&key).cloned();
786            child_operations.push(match value {
787                Some(v) => v,
788                None => self.compile_any(child)?,
789            });
790        }
791        let mut next_id = self.next_id;
792        let result = expr.create_operation(self.graph, &child_operations, &mut || {
793            let id = format!("operation_{}", next_id);
794            next_id += 1;
795            id
796        });
797        self.next_id = next_id;
798        let operation = result?;
799        self.operations.insert(Key(expr), operation.clone());
800        Ok(operation)
801    }
802}
803
804////////////////////////
805
#[cfg(test)]
mod tests {
    use super::super::Graph;
    use super::*;

    // Checks the precedence-driven parenthesization of `Display` output.
    #[test]
    fn test_display() {
        // `+` / `-`: equal-precedence right operands are only bracketed under `-`.
        assert_eq!("1 + 2 + 3", format!("{}", (Expr::from(1) + 2) + 3));
        assert_eq!(
            "1 + 2 + 3",
            format!("{}", Expr::from(1) + (Expr::from(2) + 3))
        );
        assert_eq!("1 + 2 - 3", format!("{}", (Expr::from(1) + 2) - 3));
        assert_eq!(
            "1 - (2 + 3)",
            format!("{}", Expr::from(1) - (Expr::from(2) + 3))
        );

        // `*` with additive and multiplicative operands.
        assert_eq!("(1 + 2) * 3", format!("{}", (Expr::from(1) + 2) * 3));
        assert_eq!(
            "1 * (2 + 3)",
            format!("{}", Expr::from(1) * (Expr::from(2) + 3))
        );
        assert_eq!("1 * 2 * 3", format!("{}", (Expr::from(1) * 2) * 3));
        assert_eq!(
            "1 * 2 * 3",
            format!("{}", Expr::from(1) * (Expr::from(2) * 3))
        );

        // `/` is non-associative: equal-precedence right operands are bracketed.
        assert_eq!("(1 + 2) / 3", format!("{}", (Expr::from(1) + 2) / 3));
        assert_eq!(
            "1 / (2 + 3)",
            format!("{}", Expr::from(1) / (Expr::from(2) + 3))
        );
        assert_eq!("1 * 2 / 3", format!("{}", (Expr::from(1) * 2) / 3));
        assert_eq!(
            "1 / (2 * 3)",
            format!("{}", Expr::from(1) / (Expr::from(2) * 3))
        );

        // `%` follows the same rules as `/`.
        assert_eq!("(1 + 2) % 3", format!("{}", (Expr::from(1) + 2) % 3));
        assert_eq!(
            "1 % (2 + 3)",
            format!("{}", Expr::from(1) % (Expr::from(2) + 3))
        );
        assert_eq!("1 * 2 % 3", format!("{}", (Expr::from(1) * 2) % 3));
        assert_eq!(
            "1 % (2 * 3)",
            format!("{}", Expr::from(1) % (Expr::from(2) * 3))
        );

        // Negation brackets anything that is not an atom.
        assert_eq!("-1", format!("{}", -Expr::from(1)));
        assert_eq!("-(-1)", format!("{}", -(-Expr::from(1))));
        assert_eq!("-(1 + 2)", format!("{}", -(Expr::from(1) + 2)));

        // Variables and placeholders print their names.
        assert_eq!("x", format!("{}", <Variable<f32>>::new(&vec![2, 3], "x")));

        assert_eq!(
            "x",
            format!("{}", <Placeholder<f32>>::new(&vec![2, 3], "x"))
        );

        assert_eq!(
            "x = 1 + 2",
            format!(
                "{}",
                Assign::new(
                    <Placeholder<f32>>::new_expr(&vec![2, 3], "x"),
                    Expr::from(1.0f32) + 2.0f32
                )
            )
        );
    }

    // Smoke test: compiling an expression that reuses `w`, then an assignment to `w`,
    // must succeed against a fresh graph.
    #[test]
    fn test_compile() {
        let mut g = Graph::new();

        let x = <Placeholder<f32>>::new_expr(&vec![2, 3], "x");
        let w = <Variable<f32>>::new_expr(&vec![2, 3], "w");

        let mut compiler = Compiler::new(&mut g);

        compiler
            .compile(x * w.clone() / w.clone() % w.clone() + w.clone() - w.clone())
            .unwrap();

        compiler
            .compile(Assign::to(w, ::std::iter::repeat(1.)).unwrap())
            .unwrap();
    }

    // Pins the exact (unsimplified) symbolic derivative with respect to `x`, as
    // rendered by `Display`, for each kind of expression.
    #[test]
    fn test_derivative_by_variable() {
        let x = <Variable<f32>>::new_expr(&[], "x");
        let y = <Variable<f32>>::new_expr(&[], "y");
        for &(ref expected, ref expression) in [
            ("0", Expr::from(1.0f32)),
            ("1", x.clone()),
            ("0", y.clone()),
            ("1 + 0", x.clone() + y.clone()),
            ("1 - 0", x.clone() - y.clone()),
            ("1 * x + x * 1", x.clone() * x.clone()),
            ("1 * y + x * 0", x.clone() * y.clone()),
            (
                "(1 * x + x * 1) * x + x * x * 1",
                x.clone() * x.clone() * x.clone(),
            ),
            ("(1 * y - x * 0) / (y * y)", x.clone() / y.clone()),
            ("1 - x // y * 0", x.clone() % y.clone()),
            ("0 - y // x * 1", y.clone() % x.clone()),
            (
                "(y * (1 - (1 - x // y * 0)) - (x - x % y) * 0) / (y * y)",
                TruncateDiv::new_expr(x.clone(), y.clone()),
            ),
        ]
        .iter()
        {
            assert_eq!(
                *expected,
                format!("{}", expression.derivative_by_variable("x").unwrap())
            );
        }
    }
}