1#![deny(missing_docs)]
11use std::collections::HashMap;
12
13use crate::value::{Expr, Operation};
14
/// A compiled, tape-based form of an expression tree.
///
/// The tree is flattened in post-order into parallel vectors, so index `i`
/// of every vector describes the same node and the root is always the last
/// entry. Operands are stored at lower indices than the nodes that consume
/// them, which lets the forward pass run front-to-back and the backward
/// pass back-to-front.
pub struct CompiledExpr {
    // Operation each node applies (`Operation::None` for leaves).
    operations: Vec<Operation>,
    // Tape index of each node's first operand; `None` for leaves.
    lhs: Vec<Option<usize>>,
    // Tape index of each node's second operand; `None` for leaves and unary ops.
    rhs: Vec<Option<usize>>,
    /// Forward value of each node; the last entry is the root's result.
    pub results: Vec<f64>,
    /// Gradient of the root with respect to each node, filled in by `learn`.
    pub gradients: Vec<f64>,
    // Whether each node is a learnable leaf that `learn` may update.
    is_learnable: Vec<bool>,
    // Maps a node's optional name to its tape index for lookup by name.
    names_to_index: HashMap<String, usize>,
}
30
31impl CompiledExpr {
32 fn consume_expr(&mut self, expr: Expr) {
33 let lhs = if let Some(operand1) = expr.operand1 {
34 self.consume_expr(*operand1);
35 Some(self.results.len() - 1)
36 } else {
37 None
38 };
39
40 let rhs = if let Some(operand2) = expr.operand2 {
41 self.consume_expr(*operand2);
42 Some(self.results.len() - 1)
43 } else {
44 None
45 };
46
47 self.lhs.push(lhs);
48 self.rhs.push(rhs);
49 self.results.push(expr.result);
50 self.operations.push(expr.operation);
51 self.gradients.push(expr.grad);
52 self.is_learnable.push(expr.is_learnable);
53 if let Some(name) = expr.name {
54 self.names_to_index.insert(name, self.results.len() - 1);
55 }
56 }
57
58 pub fn from_expr(expr: Expr) -> Self {
74 let parameter_count = expr.parameter_count(false);
75 let mut tape = CompiledExpr {
76 operations: Vec::with_capacity(parameter_count),
77 lhs: Vec::with_capacity(parameter_count),
78 rhs: Vec::with_capacity(parameter_count),
79 results: Vec::with_capacity(parameter_count),
80 gradients: Vec::with_capacity(parameter_count),
81 is_learnable: Vec::with_capacity(parameter_count),
82 names_to_index: HashMap::new(),
83 };
84
85 tape.consume_expr(expr);
86
87 tape
88 }
89
90 pub fn recalculate(&mut self) {
114 for i in 0..self.results.len() {
115 let operation = self.operations[i];
116 let lhs_index = self.lhs[i];
117 let rhs_index = self.rhs[i];
118
119 let lhs_value = if let Some(index) = lhs_index {
120 self.results[index]
121 } else {
122 0.0 };
124
125 let rhs_value = if let Some(index) = rhs_index {
126 self.results[index]
127 } else {
128 0.0 };
130
131 self.results[i] = match operation {
132 Operation::Add => lhs_value + rhs_value,
133 Operation::Sub => lhs_value - rhs_value,
134 Operation::Mul => lhs_value * rhs_value,
135 Operation::Div => lhs_value / rhs_value,
136 Operation::None => self.results[i], Operation::Tanh => lhs_value.tanh(),
138 Operation::Exp => lhs_value.exp(),
139 Operation::Pow => lhs_value.powf(rhs_value),
140 Operation::Log => lhs_value.ln(),
141 Operation::ReLU => lhs_value.max(0.0),
142 Operation::Neg => -lhs_value,
143 };
144 }
145 }
146
147 pub fn learn(&mut self, learning_rate: f64) {
179 self.gradients[self.results.len() - 1] = 1.0;
181
182 for i in (0..self.results.len()).rev() {
183 let operation = self.operations[i];
184 let lhs_index = self.lhs[i].unwrap_or(0);
185 let rhs_index = self.rhs[i].unwrap_or(0);
186
187 let lhs_result = if let Some(index) = self.lhs[i] {
188 self.results[index]
189 } else {
190 0.0 };
192
193 let rhs_result = if let Some(index) = self.rhs[i] {
194 self.results[index]
195 } else {
196 0.0 };
198 let result = self.results[i];
199 let gradient = self.gradients[i];
200
201 match operation {
202 Operation::None => {
204 if self.is_learnable[i] {
207 self.results[i] -= learning_rate * self.gradients[i];
208 }
209 }
210 Operation::Tanh => {
212 let tanh_grad = 1.0 - (result * result);
213 self.gradients[lhs_index] = gradient * tanh_grad;
214 }
215 Operation::Exp => {
216 self.gradients[lhs_index] = gradient * result;
217 }
218 Operation::ReLU => {
219 self.gradients[lhs_index] = if result > 0.0 {
220 1.0
221 } else {
222 0.0
223 };
224 }
225 Operation::Log => {
226 self.gradients[lhs_index] = gradient / result;
227 }
228 Operation::Neg => {
229 self.gradients[lhs_index] = -gradient;
230 }
231 Operation::Add => {
233 self.gradients[lhs_index] = gradient;
234 self.gradients[rhs_index] = gradient;
235 }
236 Operation::Sub => {
237 self.gradients[lhs_index] = gradient;
238 self.gradients[rhs_index] = -gradient;
239 }
240 Operation::Mul => {
241 self.gradients[lhs_index] = gradient * rhs_result;
242 self.gradients[rhs_index] = gradient * lhs_result;
243 }
244 Operation::Div => {
245 self.gradients[lhs_index] = gradient / rhs_result;
246 self.gradients[rhs_index] = -gradient * lhs_result / (rhs_result * rhs_result);
247 }
248 Operation::Pow => {
249 let exponent = rhs_result;
250 let base = lhs_result;
251
252 self.gradients[lhs_index] = gradient * exponent * base.powf(exponent - 1.0);
253 self.gradients[rhs_index] = gradient * lhs_result.ln() * result;
254 }
255 }
256 }
257 }
258
259 pub fn result(&self) -> f64 {
268 if self.results.is_empty() {
269 0.0
270 } else {
271 *self.results.last().unwrap()
272 }
273 }
274
275 pub fn get_grad_by_name(&self, name: &str) -> Option<f64> {
288 if let Some(&index) = self.names_to_index.get(name) {
289 return Some(self.gradients[index]);
290 }
291 None
292 }
293
294 pub fn set(&mut self, name: &str, value: f64) {
315 if let Some(&index) = self.names_to_index.get(name) {
316 self.results[index] = value;
317 }
318 }
319}
320
#[cfg(test)]
mod tests {
    use super::*;

    /// Fails unless `f1` and `f2` agree to within 1e-7.
    fn assert_float_eq(f1: f64, f2: f64) {
        let places = 7;
        let tolerance = 10.0_f64.powi(-places);
        assert!(
            (f1 - f2).abs() < tolerance,
            "{} != {} (tol: {})",
            f1,
            f2,
            tolerance
        );
    }

    /// Builds a two-operand expression via `combine`, runs one backward
    /// pass, and returns the gradients of the leaves named "a" and "b".
    fn binary_grads(a: f64, b: f64, combine: fn(Expr, Expr) -> Expr) -> (f64, f64) {
        let mut lhs = Expr::new_leaf(a);
        lhs.name = Some("a".to_string());
        let mut rhs = Expr::new_leaf(b);
        rhs.name = Some("b".to_string());

        let mut tape = CompiledExpr::from_expr(combine(lhs, rhs));
        tape.learn(1e-09);

        (
            tape.get_grad_by_name("a").unwrap(),
            tape.get_grad_by_name("b").unwrap(),
        )
    }

    #[test]
    fn test_from_expr_multilevel() {
        // (a + b) * (c - d), with only b and d named.
        let a = Expr::new_leaf(2.0);
        let b = Expr::new_leaf_with_name(3.0, "b");
        let c = Expr::new_leaf(5.0);
        let d = Expr::new_leaf_with_name(1.0, "d");
        let product = (a + b) * (c - d);

        let tape = CompiledExpr::from_expr(product);

        assert_eq!(tape.results.len(), 7);
        assert_eq!(tape.operations.len(), 7);
        assert_eq!(tape.lhs.len(), 7);
        assert_eq!(tape.rhs.len(), 7);
        assert_eq!(tape.gradients.len(), 7);

        // Post-order layout: a, b, a+b, c, d, c-d, product.
        let expected = [
            (2.0, None, None, Operation::None),
            (3.0, None, None, Operation::None),
            (5.0, Some(0), Some(1), Operation::Add),
            (5.0, None, None, Operation::None),
            (1.0, None, None, Operation::None),
            (4.0, Some(3), Some(4), Operation::Sub),
            (20.0, Some(2), Some(5), Operation::Mul),
        ];
        for (i, &(value, lhs, rhs, op)) in expected.iter().enumerate() {
            assert_eq!(tape.results[i], value);
            assert_eq!(tape.lhs[i], lhs);
            assert_eq!(tape.rhs[i], rhs);
            assert_eq!(tape.operations[i], op);
            assert_eq!(tape.gradients[i], 0.0);
        }

        // Only the named leaves are registered.
        assert_eq!(tape.names_to_index.get("b"), Some(&1));
        assert_eq!(tape.names_to_index.get("d"), Some(&4));
        assert!(tape.names_to_index.get("a").is_none());
        assert!(tape.names_to_index.get("c").is_none());
    }

    #[test]
    fn test_recalculate() {
        let a = Expr::new_leaf(2.0);
        let b = Expr::new_leaf(3.0);
        let mut tape = CompiledExpr::from_expr(a + b);

        tape.recalculate();
        assert_eq!(tape.results[2], 5.0);

        // Change the leaves and rerun the forward pass.
        tape.results[0] = 4.0;
        tape.results[1] = 6.0;
        tape.recalculate();
        assert_eq!(tape.results[2], 10.0);
    }

    #[test]
    fn test_learn_simple() {
        let mut tape = CompiledExpr::from_expr(Expr::new_leaf(1.0));
        assert_eq!(tape.result(), 1.0);

        tape.learn(1e-01);
        assert_eq!(tape.result(), 0.9);
    }

    #[test]
    fn test_learn_skips_non_learnable() {
        let mut leaf = Expr::new_leaf(1.0);
        leaf.is_learnable = false;
        let mut tape = CompiledExpr::from_expr(leaf);
        assert_eq!(tape.result(), 1.0);

        tape.learn(1e-01);
        assert_eq!(tape.result(), 1.0);
    }

    #[test]
    fn test_learn_multilevel() {
        let leaf = Expr::new_leaf(1.0);
        let mut tape = CompiledExpr::from_expr(leaf.tanh());
        assert_eq!(tape.result(), 0.7615941559557649);

        tape.learn(1e-09);
        tape.recalculate();

        assert_eq!(tape.result(), 0.7615941557793864);
    }

    #[test]
    fn test_backpropagation_add() {
        let (grad_a, grad_b) = binary_grads(1.0, 2.0, |a, b| a + b);
        assert_eq!(grad_a, 1.0);
        assert_eq!(grad_b, 1.0);
    }

    #[test]
    fn test_backpropagation_sub() {
        let (grad_a, grad_b) = binary_grads(1.0, 2.0, |a, b| a - b);
        assert_eq!(grad_a, 1.0);
        assert_eq!(grad_b, -1.0);
    }

    #[test]
    fn test_backpropagation_mul() {
        let (grad_a, grad_b) = binary_grads(3.0, 4.0, |a, b| a * b);
        assert_eq!(grad_a, 4.0);
        assert_eq!(grad_b, 3.0);
    }

    #[test]
    fn test_backpropagation_div() {
        let (grad_a, grad_b) = binary_grads(3.0, 4.0, |a, b| a / b);
        assert_eq!(grad_a, 0.25);
        assert_eq!(grad_b, -0.1875);
    }

    #[test]
    fn test_backpropagation_tanh() {
        let mut leaf = Expr::new_leaf(0.0);
        leaf.name = Some("a".to_string());
        let mut tape = CompiledExpr::from_expr(leaf.tanh());

        tape.learn(1e-09);

        assert_float_eq(tape.get_grad_by_name("a").unwrap(), 1.0);
    }

    #[test]
    fn test_backpropagation_relu() {
        let mut leaf = Expr::new_leaf(-1.0);
        leaf.name = Some("a".to_string());
        let mut tape = CompiledExpr::from_expr(leaf.relu());

        tape.learn(1e-09);

        assert_eq!(tape.get_grad_by_name("a").unwrap(), 0.0);
    }

    #[test]
    fn test_backpropagation_exp() {
        let mut leaf = Expr::new_leaf(0.0);
        leaf.name = Some("a".to_string());
        let mut tape = CompiledExpr::from_expr(leaf.exp());

        tape.learn(1e-09);

        assert_eq!(tape.get_grad_by_name("a").unwrap(), 1.0);
    }

    #[test]
    fn test_backpropagation_pow() {
        let (grad_a, grad_b) = binary_grads(2.0, 3.0, |base, exp| base.pow(exp));
        assert_eq!(grad_a, 12.0);
        assert_eq!(grad_b, 5.545177444479562);
    }

    #[test]
    fn test_backpropagation_mixed_tree() {
        let mut operand1 = Expr::new_leaf(1.0);
        operand1.name = Some("operand1".to_string());
        let mut operand2 = Expr::new_leaf(2.0);
        operand2.name = Some("operand2".to_string());
        let mut sum = operand1 + operand2;
        sum.name = Some("expr3".to_string());
        let mut tape = CompiledExpr::from_expr(sum.tanh());

        tape.learn(1e-09);

        // tanh's local gradient flows unchanged through the addition,
        // so all three nodes share the same gradient.
        let expected = 0.009866037165440211;
        assert_eq!(tape.get_grad_by_name("expr3").unwrap(), expected);
        assert_eq!(tape.get_grad_by_name("operand1").unwrap(), expected);
        assert_eq!(tape.get_grad_by_name("operand2").unwrap(), expected);
    }

    #[test]
    fn test_backpropagation_karpathys_example() {
        /// Builds a named learnable leaf.
        fn named(value: f64, name: &str) -> Expr {
            let mut leaf = Expr::new_leaf(value);
            leaf.name = Some(name.to_string());
            leaf
        }

        let x1 = named(2.0, "x1");
        let x2 = named(0.0, "x2");
        let w1 = named(-3.0, "w1");
        let w2 = named(1.0, "w2");
        let b = named(6.8813735870195432, "b");

        let mut x1w1 = x1 * w1;
        x1w1.name = Some("x1w1".to_string());
        let mut x2w2 = x2 * w2;
        x2w2.name = Some("x2w2".to_string());
        let mut x1w1_x2w2 = x1w1 + x2w2;
        x1w1_x2w2.name = Some("x1w1_x2w2".to_string());
        let mut n = x1w1_x2w2 + b;
        n.name = Some("n".to_string());

        let mut tape = CompiledExpr::from_expr(n.tanh());
        tape.learn(1e-09);

        let expected = [
            ("n", 0.5),
            ("x1w1_x2w2", 0.5),
            ("b", 0.5),
            ("x1w1", 0.5),
            ("x2w2", 0.5),
            ("x1", -1.5),
            ("w1", 1.0),
            ("x2", 0.5),
            ("w2", 0.0),
        ];
        for &(name, grad) in expected.iter() {
            assert_float_eq(tape.get_grad_by_name(name).unwrap(), grad);
        }
    }
}