1use super::utils::*;
2use super::component::Component;
3use super::operators::Operator::*;
4
5fn chain_rule(left: Component, right: &Component) -> Component {
6 create_binary(Multiply, left, derive_component(right))
7}
8
9pub fn derive_component(expr: &Component) -> Component {
10 match expr {
11 Component::Number(_) => Component::Number(0.0),
12 Component::Variable(_) => Component::Number(1.0),
13
14 Component::Function { operator, values } => match operator {
15 Add => create_binary(
16 Add,
17 derive_component(&values[0]),
18 derive_component(&values[1]),
19 ),
20
21 Subtract => create_binary(
22 Subtract,
23 derive_component(&values[0]),
24 derive_component(&values[1]),
25 ),
26
27 Multiply => create_binary(
28 Add,
29 chain_rule(values[0].clone(), &values[1]),
30 chain_rule(values[1].clone(), &values[0]),
31 ),
32
33 Exponent | Pow => {
34 if let Component::Number(f) = &values[1] {
36 chain_rule(
37 create_binary(
38 Multiply,
39 Component::Number(*f),
40 create_binary(Exponent, values[0].clone(), Component::Number(f - 1.0)),
41 ),
42 &values[0],
43 )
44 } else if let Component::Number(f) = &values[0] {
46 create_binary(
47 Multiply,
48 create_unary(Ln, Component::Number(*f)),
49 chain_rule(expr.clone(), &expr),
50 )
51 } else {
52 create_binary(
54 Multiply,
55 expr.clone(),
56 create_binary(
57 Add,
58 create_binary(
59 Multiply,
60 derive_component(&values[1]),
61 create_unary(Ln, values[0].clone()),
62 ),
63 create_binary(
64 Multiply,
65 values[1].clone(),
66 create_binary(Divide, derive_component(&values[0]), values[0].clone()),
67 ),
68 ),
69 )
70 }
71 }
72
73 Log => chain_rule(
74 create_binary(
75 Divide,
76 derive_component(&values[0]),
77 create_binary(
78 Multiply,
79 create_unary(Ln, values[1].clone()),
80 values[0].clone(),
81 ),
82 ),
83 &values[0],
84 ),
85
86 Ln => chain_rule(
87 create_binary(Divide, derive_component(&values[0]), values[0].clone()),
88 &values[0],
89 ),
90
91 Sin => chain_rule(create_unary(Cos, values[0].clone()), &values[0]),
92
93 Cos => chain_rule(
94 create_binary(
95 Multiply,
96 Component::Number(-1.0),
97 create_unary(Sin, values[0].clone()),
98 ),
99 &values[0],
100 ),
101
102 Tan => chain_rule(
103 create_binary(
104 Pow,
105 create_unary(Sec, values[0].clone()),
106 Component::Number(2.0),
107 ),
108 &values[0],
109 ),
110
111 Sec => chain_rule(
112 create_binary(
113 Multiply,
114 create_unary(Sec, values[0].clone()),
115 create_unary(Tan, values[0].clone()),
116 ),
117 &values[0],
118 ),
119
120 Csc => chain_rule(
121 create_binary(
122 Multiply,
123 Component::Number(-1.0),
124 create_binary(
125 Multiply,
126 create_unary(Csc, values[0].clone()),
127 create_unary(Cot, values[0].clone()),
128 ),
129 ),
130 &values[0],
131 ),
132
133 Cot => chain_rule(
134 create_binary(
135 Multiply,
136 Component::Number(-1.0),
137 create_binary(
138 Pow,
139 create_unary(Csc, values[0].clone()),
140 Component::Number(2.0),
141 ),
142 ),
143 &values[0],
144 ),
145
146 _ => Component::End,
147 },
148 _ => Component::End,
149 }
150}