use_calculus/
derivative.rs1use crate::CalculusError;
2
/// Finite-difference differentiator that carries the step size used for
/// every estimate it produces.
///
/// The step is stored exactly as given: [`Differentiator::new`] does not
/// check it, while [`Differentiator::try_new`] and
/// [`Differentiator::validate`] do.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Differentiator {
    /// Offset from the evaluation point to each outer sample point.
    step: f64,
}
8
9impl Differentiator {
10 #[must_use]
12 pub const fn new(step: f64) -> Self {
13 Self { step }
14 }
15
16 pub fn try_new(step: f64) -> Result<Self, CalculusError> {
38 CalculusError::validate_step(step)?;
39 Ok(Self::new(step))
40 }
41
42 pub fn validate(self) -> Result<Self, CalculusError> {
48 Self::try_new(self.step)
49 }
50
51 #[must_use]
53 pub const fn step(&self) -> f64 {
54 self.step
55 }
56
57 pub fn derivative_at<F>(self, function: F, at: f64) -> Result<f64, CalculusError>
76 where
77 F: FnMut(f64) -> f64,
78 {
79 central_difference(function, at, self.step)
80 }
81
82 pub fn second_derivative_at<F>(self, function: F, at: f64) -> Result<f64, CalculusError>
89 where
90 F: FnMut(f64) -> f64,
91 {
92 second_central_difference(function, at, self.step)
93 }
94}
95
96#[must_use = "derivative estimates should be used or handled"]
103pub fn central_difference<F>(mut function: F, at: f64, step: f64) -> Result<f64, CalculusError>
104where
105 F: FnMut(f64) -> f64,
106{
107 let at = CalculusError::validate_point("at", at)?;
108 let step = CalculusError::validate_step(step)?;
109 let left = evaluate(&mut function, at - step)?;
110 let right = evaluate(&mut function, at + step)?;
111
112 Ok((right - left) / (2.0 * step))
113}
114
115#[must_use = "second-derivative estimates should be used or handled"]
122pub fn second_central_difference<F>(
123 mut function: F,
124 at: f64,
125 step: f64,
126) -> Result<f64, CalculusError>
127where
128 F: FnMut(f64) -> f64,
129{
130 let at = CalculusError::validate_point("at", at)?;
131 let step = CalculusError::validate_step(step)?;
132 let left = evaluate(&mut function, at - step)?;
133 let center = evaluate(&mut function, at)?;
134 let right = evaluate(&mut function, at + step)?;
135 let step_squared = step * step;
136 let numerator = (-2.0_f64).mul_add(center, left + right);
137
138 Ok(numerator / step_squared)
139}
140
141fn evaluate<F>(function: &mut F, input: f64) -> Result<f64, CalculusError>
142where
143 F: FnMut(f64) -> f64,
144{
145 let input = CalculusError::validate_point("sample", input)?;
146 let value = function(input);
147
148 CalculusError::validate_evaluation(input, value)
149}
150
#[cfg(test)]
mod tests {
    use super::{CalculusError, Differentiator, central_difference, second_central_difference};

    /// Asserts that `left` lies within `tolerance` of `right`.
    fn assert_close(left: f64, right: f64, tolerance: f64) {
        assert!(
            (left - right).abs() <= tolerance,
            "expected {left} to be within {tolerance} of {right}"
        );
    }

    #[test]
    fn validates_differentiator_steps() {
        assert!(matches!(
            Differentiator::try_new(f64::INFINITY),
            Err(CalculusError::NonFiniteStep(f64::INFINITY))
        ));
        assert!(matches!(
            Differentiator::try_new(0.0),
            Err(CalculusError::NonPositiveStep(0.0))
        ));
    }

    #[test]
    fn validate_checks_steps_from_unchecked_constructor() {
        // `new` is unchecked, so `validate` is the only guard here.
        assert!(matches!(
            Differentiator::new(0.0).validate(),
            Err(CalculusError::NonPositiveStep(0.0))
        ));
        assert_eq!(
            Differentiator::new(0.5).validate().ok(),
            Some(Differentiator::new(0.5))
        );
    }

    #[test]
    fn step_accessor_returns_configured_step() {
        // Compare bit patterns: the step must be stored verbatim.
        assert_eq!(
            Differentiator::new(1.0e-3).step().to_bits(),
            1.0e-3_f64.to_bits()
        );
    }

    #[test]
    fn differentiator_methods_estimate_derivatives() -> Result<(), CalculusError> {
        let differentiator = Differentiator::try_new(1.0e-3)?;
        // Non-capturing closures are `Copy`, so `cube` can be passed twice.
        let cube = |x: f64| x.powi(3);

        // Central difference of x³ is exactly 3x² + h², i.e. 12 + 1e-6 here.
        assert_close(differentiator.derivative_at(cube, 2.0)?, 12.0, 1.0e-5);
        // The three-point second difference is exact for cubics up to rounding.
        assert_close(differentiator.second_derivative_at(cube, 2.0)?, 12.0, 1.0e-6);
        Ok(())
    }

    #[test]
    fn computes_first_derivatives() -> Result<(), CalculusError> {
        let slope = central_difference(|x| x.powi(2), 3.0, 1.0e-5)?;

        assert_close(slope, 6.0, 1.0e-6);
        Ok(())
    }

    #[test]
    fn computes_second_derivatives() -> Result<(), CalculusError> {
        let curvature = second_central_difference(|x| x.powi(2), 1.5, 1.0e-4)?;

        assert_close(curvature, 2.0, 1.0e-6);
        Ok(())
    }

    #[test]
    fn free_functions_validate_steps() {
        assert!(matches!(
            central_difference(|x| x, 1.0, 0.0),
            Err(CalculusError::NonPositiveStep(0.0))
        ));
        assert!(matches!(
            second_central_difference(|x| x, 1.0, f64::INFINITY),
            Err(CalculusError::NonFiniteStep(f64::INFINITY))
        ));
    }

    #[test]
    fn rejects_non_finite_points() {
        assert!(matches!(
            central_difference(|x| x, f64::NAN, 1.0e-5),
            Err(CalculusError::NonFinitePoint { name: "at", .. })
        ));
        assert!(matches!(
            second_central_difference(|x| x, f64::NAN, 1.0e-3),
            Err(CalculusError::NonFinitePoint { name: "at", .. })
        ));
    }

    #[test]
    fn rejects_non_finite_evaluations() {
        assert!(matches!(
            central_difference(|_| f64::NAN, 1.0, 1.0e-5),
            Err(CalculusError::NonFiniteEvaluation { .. })
        ));
        assert!(matches!(
            second_central_difference(|_| f64::NAN, 1.0, 1.0e-3),
            Err(CalculusError::NonFiniteEvaluation { .. })
        ));
    }
}