1#![cfg(feature = "tensorflow_unstable")]
5
6use super::Graph;
7use super::Operation;
8use super::Shape;
9use super::Status;
10use super::Tensor;
11use super::TensorType;
12use std::cmp::Eq;
13use std::collections::HashMap;
14use std::convert::From;
15use std::fmt::Debug;
16use std::fmt::Display;
17use std::fmt::Error;
18use std::fmt::Formatter;
19use std::hash::Hash;
20use std::hash::Hasher;
21use std::marker::PhantomData;
22use std::ops;
23use std::rc::Rc;
24
/// Binding strength of an expression node, used when pretty-printing
/// expressions to decide where parentheses are required.
///
/// Variants are declared from loosest to tightest binding. The derived
/// ordering follows declaration order, so a smaller value binds less tightly.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum OpLevel {
    /// Assignment (`x = ...`), the loosest level.
    Assign,

    /// Additive operators (`+`, `-`).
    Add,

    /// Multiplicative operators (`*`, `/`, `%`, `//`).
    Mul,

    /// Prefix unary operators such as negation.
    Unary,

    /// Leaves that never need parentheses (constants, variables, placeholders).
    Atom,
}
44
45#[derive(Debug, Clone)]
52pub struct Expr<T: TensorType> {
53 expr: Rc<dyn ExprImpl<T>>,
54}
55
56impl<T: TensorType> Expr<T> {
57 pub fn new<I>(expr: I) -> Expr<T>
59 where
60 I: ExprImpl<T> + 'static,
61 {
62 Expr {
63 expr: Rc::new(expr),
64 }
65 }
66}
67
68impl<T: TensorType> ops::Deref for Expr<T> {
69 type Target = dyn ExprImpl<T>;
70
71 fn deref(&self) -> &Self::Target {
72 self.expr.deref()
73 }
74}
75
76impl<T: TensorType> Display for Expr<T> {
77 fn fmt(&self, f: &mut Formatter) -> Result<(), Error> {
78 Display::fmt(&self.expr, f)
79 }
80}
81
82impl<T: TensorType> From<T> for Expr<T> {
83 fn from(value: T) -> Self {
84 Expr::new(value)
85 }
86}
87
/// Static shape information that an expression may expose through
/// `ExprImpl::shape_hint`.
///
/// This is a small borrowed value, so it derives the common comparison and
/// copy traits, which lets callers compare and duplicate hints freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ShapeHint<'a> {
    /// Nothing is known about the expression's shape.
    Unknown,

    /// The expression is known to have exactly these dimensions.
    Exactly(&'a [u64]),
}
99
100pub trait ExprImpl<T: TensorType>: Display + Debug {
105 fn op_level(&self) -> OpLevel;
107
108 fn children(&self) -> Vec<Box<dyn AnyExpr>>; fn create_operation(
118 &self,
119 graph: &mut Graph,
120 children: &[Operation],
121 id_gen: &mut dyn FnMut() -> String,
122 ) -> Result<Operation, Status>;
123
124 fn derivative_by_variable(&self, var: &str) -> Result<Expr<T>, Status>;
126
127 fn shape_hint(&self) -> ShapeHint {
129 ShapeHint::Unknown
130 }
131}
132
133impl<T: TensorType> ExprImpl<T> for T {
134 fn op_level(&self) -> OpLevel {
135 OpLevel::Atom
136 }
137
138 fn children(&self) -> Vec<Box<dyn AnyExpr>> {
139 vec![]
140 }
141
142 fn create_operation(
143 &self,
144 graph: &mut Graph,
145 _children: &[Operation],
146 id_gen: &mut dyn FnMut() -> String,
147 ) -> Result<Operation, Status> {
148 let mut nd = graph.new_operation("Const", &id_gen())?;
149 nd.set_attr_type("dtype", T::data_type())?;
150 let mut value = Tensor::new(&[1]);
151 value[0] = self.clone();
152 nd.set_attr_tensor("value", value)?;
153 nd.finish()
154 }
155
156 fn derivative_by_variable(&self, _var: &str) -> Result<Expr<T>, Status> {
157 Ok(Expr::from(T::zero()))
158 }
159}
160
// Defines a binary-operator expression node `$name` together with:
// - `std::ops::$name` impls, so `Expr<T> op Expr<T>` and `Expr<T> op T`
//   build the node;
// - a `Display` impl printing `left $op right` with minimal parentheses;
// - an `ExprImpl` impl whose `create_operation` emits the TensorFlow op
//   `$tf_op` with the two children as inputs.
//
// Arguments:
// - `$name`     : struct name, also the `std::ops` trait name (e.g. `Add`);
// - `$fn_name`  : that trait's method (e.g. `add`);
// - `$op`       : operator text used when printing (e.g. "+");
// - `$op_level` : `OpLevel` variant giving this node's binding strength;
// - `$assoc`    : if true, a right operand at the same level is printed
//                 without parentheses;
// - `$tf_op`    : operation type passed to `Graph::new_operation`;
// - `$doc`      : doc string attached to the generated struct;
// - `$ximpl`    : extra items spliced into the `ExprImpl` impl. The callers in
//                 this file use it to supply `derivative_by_variable`.
macro_rules! impl_bin_op {
    ($name:ident, $fn_name:ident, $op:expr, $op_level:ident, $assoc:expr,
     $tf_op:expr, $doc:expr, $($ximpl:tt)*) => {
        #[doc = $doc]
        #[derive(Debug)]
        pub struct $name<T: TensorType> {
            left: Expr<T>,
            right: Expr<T>,
        }

        // `Expr<T> op Expr<T>`
        impl<T: TensorType> ops::$name for Expr<T> {
            type Output = Expr<T>;

            fn $fn_name(self, rhs: Expr<T>) -> Expr<T> {
                Expr::new($name {
                    left: self,
                    right: rhs,
                })
            }
        }

        // `Expr<T> op T`: the bare value is lifted into an atomic expression.
        impl<T: TensorType> ops::$name<T> for Expr<T> {
            type Output = Expr<T>;

            fn $fn_name(self, rhs: T) -> Expr<T> {
                Expr::new($name {
                    left: self,
                    right: Expr::from(rhs),
                })
            }
        }

        impl<T: TensorType> Display for $name<T> {
            fn fmt(&self, f: &mut Formatter) -> Result<(), Error> {
                // Left operand: wrap only if it binds strictly looser than
                // this operator, e.g. `(a + b) * c` but `a * b * c`.
                if self.left.op_level() < OpLevel::$op_level {
                    write!(f, "({})", self.left)?;
                } else {
                    write!(f, "{}", self.left)?;
                }
                write!(f, concat!(" ", $op, " "))?;
                // Right operand: for non-associative operators a same-level
                // operand must also be wrapped, e.g. `a - (b + c)`.
                // NOTE(review): `$assoc` only sees the *level* of the right
                // operand, not which operator it is. For `Mul`, a right-hand
                // `Rem` (or `Div`/`TruncateDiv`) is printed unwrapped, so
                // `a * (b % c)` displays as `a * b % c`. Consider tracking
                // operator identity.
                let paren = if $assoc {
                    self.right.op_level() < OpLevel::$op_level
                } else {
                    self.right.op_level() <= OpLevel::$op_level
                };
                if paren {
                    write!(f, "({})", self.right)
                } else {
                    write!(f, "{}", self.right)
                }
            }
        }

        impl<T: TensorType> ExprImpl<T> for $name<T> {
            fn op_level(&self) -> OpLevel {
                OpLevel::$op_level
            }

            // Order matters: `create_operation` receives the compiled
            // operations in this order (left, then right).
            fn children(&self) -> Vec<Box<dyn AnyExpr>> {
                vec![Box::new(self.left.clone()), Box::new(self.right.clone())]
            }

            fn create_operation(&self, graph: &mut Graph, children: &[Operation],
                                id_gen: &mut dyn FnMut() -> String) -> Result<Operation, Status> {
                let mut nd = graph.new_operation($tf_op, &id_gen())?;
                // children[0] / children[1] are the compiled left / right operands.
                nd.add_input(children[0].clone());
                nd.add_input(children[1].clone());
                nd.finish()
            }

            $($ximpl)*
        }
    }
}
237
238impl_bin_op!(
239 Add,
240 add,
241 "+",
242 Add,
243 true,
244 "Add",
245 "Expression resulting from adding two subexpressions.",
246 fn derivative_by_variable(&self, var: &str) -> Result<Expr<T>, Status> {
247 Ok(self.left.derivative_by_variable(var)? + self.right.derivative_by_variable(var)?)
248 }
249);
250impl_bin_op!(
251 Sub,
252 sub,
253 "-",
254 Add,
255 false,
256 "Sub",
257 "Expression resulting from subtracting two subexpressions.",
258 fn derivative_by_variable(&self, var: &str) -> Result<Expr<T>, Status> {
259 Ok(self.left.derivative_by_variable(var)? - self.right.derivative_by_variable(var)?)
260 }
261);
262impl_bin_op!(
263 Mul,
264 mul,
265 "*",
266 Mul,
267 true,
268 "Mul",
269 "Expression resulting from multiplying two subexpressions.",
270 fn derivative_by_variable(&self, var: &str) -> Result<Expr<T>, Status> {
271 Ok(self.left.derivative_by_variable(var)? * self.right.clone()
272 + self.left.clone() * self.right.derivative_by_variable(var)?)
273 }
274);
275impl_bin_op!(
276 Div,
277 div,
278 "/",
279 Mul,
280 false,
281 "Div",
282 "Expression resulting from dividing two subexpressions.",
283 fn derivative_by_variable(&self, var: &str) -> Result<Expr<T>, Status> {
284 let num = self.left.derivative_by_variable(var)? * self.right.clone()
285 - self.left.clone() * self.right.derivative_by_variable(var)?;
286 let denom = self.right.clone() * self.right.clone();
287 Ok(num / denom)
288 }
289);
290impl_bin_op!(
291 Rem,
292 rem,
293 "%",
294 Mul,
295 false,
296 "Mod",
297 "Expression resulting from taking a modulus.",
298 fn derivative_by_variable(&self, var: &str) -> Result<Expr<T>, Status> {
299 Ok(self.left.derivative_by_variable(var)?
300 - TruncateDiv::new_expr(self.left.clone(), self.right.clone())
301 * self.right.derivative_by_variable(var)?)
302 }
303);
304
305#[derive(Debug)]
309pub struct TruncateDiv<T: TensorType> {
310 left: Expr<T>,
311 right: Expr<T>,
312}
313
314impl<T: TensorType> TruncateDiv<T> {
315 fn new(left: Expr<T>, right: Expr<T>) -> Self {
316 TruncateDiv { left, right }
317 }
318
319 pub fn new_expr(left: Expr<T>, right: Expr<T>) -> Expr<T> {
321 Expr::new(TruncateDiv::new(left, right))
322 }
323}
324
325impl<T: TensorType> Display for TruncateDiv<T> {
326 fn fmt(&self, f: &mut Formatter) -> Result<(), Error> {
327 write!(f, "{} // {}", self.left, self.right)
328 }
329}
330
331impl<T: TensorType> ExprImpl<T> for TruncateDiv<T> {
332 fn op_level(&self) -> OpLevel {
333 OpLevel::Mul
334 }
335
336 fn children(&self) -> Vec<Box<dyn AnyExpr>> {
337 vec![Box::new(self.left.clone()), Box::new(self.right.clone())]
338 }
339
340 fn create_operation(
341 &self,
342 graph: &mut Graph,
343 children: &[Operation],
344 id_gen: &mut dyn FnMut() -> String,
345 ) -> Result<Operation, Status> {
346 let mut nd = graph.new_operation("TruncateDiv", &id_gen())?;
347 nd.add_input(children[0].clone());
348 nd.add_input(children[1].clone());
349 nd.finish()
350 }
351
352 fn derivative_by_variable(&self, var: &str) -> Result<Expr<T>, Status> {
353 let diff = self.left.clone() - self.left.clone() % self.right.clone();
357 let term1 = self.right.clone() * diff.derivative_by_variable(var)?;
358 let term2 = diff * self.right.derivative_by_variable(var)?;
359 Ok((term1 - term2) / (self.right.clone() * self.right.clone()))
360 }
361}
362
363#[derive(Debug)]
367pub struct Neg<T: TensorType> {
368 expr: Expr<T>,
369}
370
371impl<T: TensorType> ops::Neg for Expr<T> {
372 type Output = Expr<T>;
373
374 fn neg(self) -> Expr<T> {
375 Expr::new(Neg { expr: self })
376 }
377}
378
379impl<T: TensorType> Display for Neg<T> {
380 fn fmt(&self, f: &mut Formatter) -> Result<(), Error> {
381 write!(f, "-")?;
382 if self.expr.op_level() <= OpLevel::Unary {
383 write!(f, "({})", self.expr)
384 } else {
385 write!(f, "{}", self.expr)
386 }
387 }
388}
389
390impl<T: TensorType> ExprImpl<T> for Neg<T> {
391 fn op_level(&self) -> OpLevel {
392 OpLevel::Unary
393 }
394
395 fn children(&self) -> Vec<Box<dyn AnyExpr>> {
396 vec![Box::new(self.expr.clone())]
397 }
398
399 fn create_operation(
400 &self,
401 graph: &mut Graph,
402 children: &[Operation],
403 id_gen: &mut dyn FnMut() -> String,
404 ) -> Result<Operation, Status> {
405 let mut nd = graph.new_operation("Neg", &id_gen())?;
406 nd.add_input(children[0].clone());
407 nd.finish()
408 }
409
410 fn derivative_by_variable(&self, var: &str) -> Result<Expr<T>, Status> {
411 Ok(-self.expr.derivative_by_variable(var)?)
412 }
413}
414
415#[derive(Debug)]
419pub struct Variable<T: TensorType> {
420 shape: Vec<u64>,
421 name: String,
422 phantom: PhantomData<T>,
423}
424
425impl<T: TensorType> Variable<T> {
426 fn new(shape: &[u64], name: &str) -> Self {
427 Variable {
428 shape: Vec::from(shape),
429 name: name.to_string(),
430 phantom: PhantomData,
431 }
432 }
433
434 pub fn new_expr(shape: &[u64], name: &str) -> Expr<T> {
436 Expr::new(Variable::new(shape, name))
437 }
438}
439
440impl<T: TensorType> Display for Variable<T> {
441 fn fmt(&self, f: &mut Formatter) -> Result<(), Error> {
442 write!(f, "{}", self.name)
443 }
444}
445
446impl<T: TensorType> ExprImpl<T> for Variable<T> {
447 fn op_level(&self) -> OpLevel {
448 OpLevel::Atom
449 }
450
451 fn children(&self) -> Vec<Box<dyn AnyExpr>> {
452 vec![]
453 }
454
455 fn create_operation(
456 &self,
457 graph: &mut Graph,
458 _children: &[Operation],
459 _id_gen: &mut dyn FnMut() -> String,
460 ) -> Result<Operation, Status> {
461 let mut nd = graph.new_operation("Variable", &self.name)?;
462 let shape = self
463 .shape
464 .iter()
465 .map(|dim_size| Some(*dim_size as i64))
466 .collect();
467
468 nd.set_attr_type("dtype", T::data_type()).unwrap();
469 nd.set_attr_shape("shape", &Shape(Some(shape))).unwrap();
470 nd.finish()
471 }
472
473 fn derivative_by_variable(&self, var: &str) -> Result<Expr<T>, Status> {
474 Ok(if var == self.name {
475 Expr::from(T::one())
476 } else {
477 Expr::from(T::zero())
478 })
479 }
480
481 fn shape_hint(&self) -> ShapeHint {
482 ShapeHint::Exactly(&self.shape)
483 }
484}
485
486#[derive(Debug)]
490pub struct Placeholder<T: TensorType> {
491 shape: Vec<u64>,
492 name: String,
493 phantom: PhantomData<T>,
494}
495
496impl<T: TensorType> Placeholder<T> {
497 fn new(shape: &[u64], name: &str) -> Self {
498 Placeholder {
499 shape: Vec::from(shape),
500 name: name.to_string(),
501 phantom: PhantomData,
502 }
503 }
504
505 pub fn new_expr(shape: &[u64], name: &str) -> Expr<T> {
507 Expr::new(Placeholder::new(shape, name))
508 }
509}
510
511impl<T: TensorType> Display for Placeholder<T> {
512 fn fmt(&self, f: &mut Formatter) -> Result<(), Error> {
513 write!(f, "{}", self.name)
514 }
515}
516
517impl<T: TensorType> ExprImpl<T> for Placeholder<T> {
518 fn op_level(&self) -> OpLevel {
519 OpLevel::Atom
520 }
521
522 fn children(&self) -> Vec<Box<dyn AnyExpr>> {
523 vec![]
524 }
525
526 fn create_operation(
527 &self,
528 graph: &mut Graph,
529 _children: &[Operation],
530 _id_gen: &mut dyn FnMut() -> String,
531 ) -> Result<Operation, Status> {
532 let mut nd = graph.new_operation("Placeholder", &self.name)?;
533 let shape = self
534 .shape
535 .iter()
536 .map(|dim_size| Some(*dim_size as i64))
537 .collect();
538
539 nd.set_attr_type("dtype", T::data_type()).unwrap();
540 nd.set_attr_shape("shape", &Shape(Some(shape))).unwrap();
541 nd.finish()
542 }
543
544 fn derivative_by_variable(&self, _var: &str) -> Result<Expr<T>, Status> {
545 Ok(Expr::from(T::zero()))
546 }
547
548 fn shape_hint(&self) -> ShapeHint {
549 ShapeHint::Exactly(&self.shape)
550 }
551}
552
553#[derive(Debug)]
557pub struct Constant<T: TensorType> {
558 tensor: Tensor<T>,
559}
560
561impl<T: TensorType> Constant<T> {
562 pub fn new(tensor: Tensor<T>) -> Self {
564 Constant { tensor }
565 }
566
567 pub fn new_expr(tensor: Tensor<T>) -> Expr<T> {
569 Expr::new(Constant { tensor })
570 }
571}
572
573impl<T: TensorType> Display for Constant<T> {
574 fn fmt(&self, f: &mut Formatter) -> Result<(), Error> {
575 write!(f, "{}", self.tensor)
576 }
577}
578
579impl<T: TensorType> ExprImpl<T> for Constant<T> {
580 fn op_level(&self) -> OpLevel {
581 OpLevel::Atom
582 }
583
584 fn children(&self) -> Vec<Box<dyn AnyExpr>> {
585 vec![]
586 }
587
588 fn create_operation(
589 &self,
590 graph: &mut Graph,
591 _children: &[Operation],
592 id_gen: &mut dyn FnMut() -> String,
593 ) -> Result<Operation, Status> {
594 let mut nd = graph.new_operation("Const", &id_gen())?;
595
596 nd.set_attr_type("dtype", T::data_type())?;
597 nd.set_attr_tensor("value", self.tensor.clone())?;
598 nd.finish()
599 }
600
601 fn derivative_by_variable(&self, _var: &str) -> Result<Expr<T>, Status> {
602 Ok(Expr::from(T::zero()))
603 }
604}
605
606#[derive(Debug)]
610pub struct Assign<T: TensorType> {
611 variable: Expr<T>,
612 value: Expr<T>,
613}
614
615impl<T: TensorType> Assign<T> {
616 fn new(variable: Expr<T>, value: Expr<T>) -> Self {
617 Assign { variable, value }
618 }
619
620 pub fn new_expr(variable: Expr<T>, value: Expr<T>) -> Expr<T> {
622 Expr::new(Assign::new(variable, value))
623 }
624
625 pub fn to(variable: Expr<T>, iterable: impl Iterator<Item = T>) -> crate::Result<Expr<T>> {
627 let constant = if let ShapeHint::Exactly(shape) = variable.expr.shape_hint() {
628 let values: Vec<_> = iterable
629 .take(shape.iter().product::<u64>() as usize)
630 .collect();
631
632 Constant::new_expr(Tensor::new(shape).with_values(&values)?)
633 } else {
634 return Err(invalid_arg!(
635 "Cannot assign to expression {} with unknown size!",
636 variable
637 ));
638 };
639
640 Ok(Assign::new_expr(variable, constant))
641 }
642}
643
644impl<T: TensorType> Display for Assign<T> {
645 fn fmt(&self, f: &mut Formatter) -> Result<(), Error> {
646 write!(f, "{} = {}", self.variable, self.value)
647 }
648}
649
650impl<T: TensorType> ExprImpl<T> for Assign<T> {
651 fn op_level(&self) -> OpLevel {
652 OpLevel::Assign
653 }
654
655 fn children(&self) -> Vec<Box<dyn AnyExpr>> {
656 vec![
657 Box::new(self.variable.clone()),
658 Box::new(self.value.clone()),
659 ]
660 }
661
662 fn create_operation(
663 &self,
664 graph: &mut Graph,
665 children: &[Operation],
666 id_gen: &mut dyn FnMut() -> String,
667 ) -> Result<Operation, Status> {
668 let mut nd = graph.new_operation("Assign", &id_gen())?;
669 nd.add_input(children[0].clone());
670 nd.add_input(children[1].clone());
671 nd.finish()
672 }
673
674 fn derivative_by_variable(&self, _var: &str) -> Result<Expr<T>, Status> {
675 Err(invalid_arg!("Cannot take the derivative of an assignment"))
676 }
677}
678
679pub trait AnyExpr: Debug {
685 fn key(&self) -> *const ();
687
688 fn children(&self) -> Vec<Box<dyn AnyExpr>>; fn create_operation(
698 &self,
699 graph: &mut Graph,
700 children: &[Operation],
701 id_gen: &mut dyn FnMut() -> String,
702 ) -> Result<Operation, Status>;
703
704 fn clone_box(&self) -> Box<dyn AnyExpr>;
709}
710
711impl<T: TensorType> AnyExpr for Expr<T> {
712 #[allow(trivial_casts)]
713 fn key(&self) -> *const () {
714 self.expr.as_ref() as *const dyn ExprImpl<T> as *const ()
715 }
716
717 fn children(&self) -> Vec<Box<dyn AnyExpr>> {
718 self.expr.children()
719 }
720
721 fn create_operation(
722 &self,
723 graph: &mut Graph,
724 children: &[Operation],
725 id_gen: &mut dyn FnMut() -> String,
726 ) -> Result<Operation, Status> {
727 self.expr.create_operation(graph, children, id_gen)
728 }
729
730 fn clone_box(&self) -> Box<dyn AnyExpr> {
731 Box::new(self.clone())
732 }
733}
734
735#[derive(Debug)]
736struct Key(Box<dyn AnyExpr>);
737
738impl PartialEq for Key {
739 fn eq(&self, other: &Key) -> bool {
740 self.0.key() == other.0.key()
741 }
742}
743
744impl Eq for Key {}
745
746impl Hash for Key {
747 fn hash<H>(&self, state: &mut H)
748 where
749 H: Hasher,
750 {
751 state.write_isize(self.0.key() as isize)
752 }
753}
754
755#[derive(Debug)]
757pub struct Compiler<'l> {
758 graph: &'l mut Graph,
759 operations: HashMap<Key, Operation>,
760 next_id: i32,
761}
762
763impl<'l> Compiler<'l> {
764 pub fn new(graph: &'l mut Graph) -> Self {
766 Compiler {
767 graph,
768 operations: HashMap::new(),
769 next_id: 0,
770 }
771 }
772
773 pub fn compile<T: TensorType>(&mut self, expr: Expr<T>) -> Result<Operation, Status> {
775 self.compile_any(Box::new(expr))
776 }
777
778 pub fn compile_any(&mut self, expr: Box<dyn AnyExpr>) -> Result<Operation, Status> {
780 let mut child_operations = vec![];
781 for child in expr.children() {
782 let key = Key(child.clone_box());
783 let value = self.operations.get(&key).cloned();
786 child_operations.push(match value {
787 Some(v) => v,
788 None => self.compile_any(child)?,
789 });
790 }
791 let mut next_id = self.next_id;
792 let result = expr.create_operation(self.graph, &child_operations, &mut || {
793 let id = format!("operation_{}", next_id);
794 next_id += 1;
795 id
796 });
797 self.next_id = next_id;
798 let operation = result?;
799 self.operations.insert(Key(expr), operation.clone());
800 Ok(operation)
801 }
802}
803
#[cfg(test)]
mod tests {
    use super::super::Graph;
    use super::*;

    /// Checks operator printing and parenthesisation.
    #[test]
    fn test_display() {
        assert_eq!("1 + 2 + 3", ((Expr::from(1) + 2) + 3).to_string());
        assert_eq!(
            "1 + 2 + 3",
            (Expr::from(1) + (Expr::from(2) + 3)).to_string()
        );
        assert_eq!("1 + 2 - 3", ((Expr::from(1) + 2) - 3).to_string());
        assert_eq!(
            "1 - (2 + 3)",
            (Expr::from(1) - (Expr::from(2) + 3)).to_string()
        );

        assert_eq!("(1 + 2) * 3", ((Expr::from(1) + 2) * 3).to_string());
        assert_eq!(
            "1 * (2 + 3)",
            (Expr::from(1) * (Expr::from(2) + 3)).to_string()
        );
        assert_eq!("1 * 2 * 3", ((Expr::from(1) * 2) * 3).to_string());
        assert_eq!(
            "1 * 2 * 3",
            (Expr::from(1) * (Expr::from(2) * 3)).to_string()
        );

        assert_eq!("(1 + 2) / 3", ((Expr::from(1) + 2) / 3).to_string());
        assert_eq!(
            "1 / (2 + 3)",
            (Expr::from(1) / (Expr::from(2) + 3)).to_string()
        );
        assert_eq!("1 * 2 / 3", ((Expr::from(1) * 2) / 3).to_string());
        assert_eq!(
            "1 / (2 * 3)",
            (Expr::from(1) / (Expr::from(2) * 3)).to_string()
        );

        assert_eq!("(1 + 2) % 3", ((Expr::from(1) + 2) % 3).to_string());
        assert_eq!(
            "1 % (2 + 3)",
            (Expr::from(1) % (Expr::from(2) + 3)).to_string()
        );
        assert_eq!("1 * 2 % 3", ((Expr::from(1) * 2) % 3).to_string());
        assert_eq!(
            "1 % (2 * 3)",
            (Expr::from(1) % (Expr::from(2) * 3)).to_string()
        );

        assert_eq!("-1", (-Expr::from(1)).to_string());
        assert_eq!("-(-1)", (-(-Expr::from(1))).to_string());
        assert_eq!("-(1 + 2)", (-(Expr::from(1) + 2)).to_string());

        assert_eq!("x", <Variable<f32>>::new(&[2, 3], "x").to_string());

        assert_eq!("x", <Placeholder<f32>>::new(&[2, 3], "x").to_string());

        assert_eq!(
            "x = 1 + 2",
            Assign::new(
                <Placeholder<f32>>::new_expr(&[2, 3], "x"),
                Expr::from(1.0f32) + 2.0f32
            )
            .to_string()
        );
    }

    /// Compiles an expression and then an assignment that shares a variable
    /// with it.
    #[test]
    fn test_compile() {
        let mut g = Graph::new();

        let x = <Placeholder<f32>>::new_expr(&[2, 3], "x");
        let w = <Variable<f32>>::new_expr(&[2, 3], "w");

        let mut compiler = Compiler::new(&mut g);

        compiler
            .compile(x * w.clone() / w.clone() % w.clone() + w.clone() - w.clone())
            .unwrap();

        compiler
            .compile(Assign::to(w, ::std::iter::repeat(1.)).unwrap())
            .unwrap();
    }

    /// Checks the printed form of symbolic derivatives with respect to `x`.
    #[test]
    fn test_derivative_by_variable() {
        let x = <Variable<f32>>::new_expr(&[], "x");
        let y = <Variable<f32>>::new_expr(&[], "y");
        let cases = [
            ("0", Expr::from(1.0f32)),
            ("1", x.clone()),
            ("0", y.clone()),
            ("1 + 0", x.clone() + y.clone()),
            ("1 - 0", x.clone() - y.clone()),
            ("1 * x + x * 1", x.clone() * x.clone()),
            ("1 * y + x * 0", x.clone() * y.clone()),
            (
                "(1 * x + x * 1) * x + x * x * 1",
                x.clone() * x.clone() * x.clone(),
            ),
            ("(1 * y - x * 0) / (y * y)", x.clone() / y.clone()),
            ("1 - x // y * 0", x.clone() % y.clone()),
            ("0 - y // x * 1", y.clone() % x.clone()),
            (
                "(y * (1 - (1 - x // y * 0)) - (x - x % y) * 0) / (y * y)",
                TruncateDiv::new_expr(x.clone(), y.clone()),
            ),
        ];
        for (expected, expression) in cases.iter() {
            assert_eq!(
                *expected,
                expression.derivative_by_variable("x").unwrap().to_string(),
                "derivative of {} by x",
                expression
            );
        }
    }
}