amari_calculus/fields/
scalar_field.rs1use crate::{CalculusError, CalculusResult};
4
/// A real-valued field over a coordinate space, tagged at the type level with a
/// `Cl(P,Q,R)` signature.
///
/// The field is backed by a plain function pointer, so it can only wrap free functions or
/// non-capturing closures.
#[derive(Clone)]
pub struct ScalarField<const P: usize, const Q: usize, const R: usize> {
    /// Number of coordinate axes the field is defined over.
    dim: usize,
    /// Maps a coordinate slice to the field's value at that point.
    function: fn(&[f64]) -> f64,
}
30
31impl<const P: usize, const Q: usize, const R: usize> ScalarField<P, Q, R> {
32 pub fn new(function: fn(&[f64]) -> f64) -> Self {
49 Self {
50 function,
51 dim: P + Q + R,
52 }
53 }
54
55 pub fn with_dimension(function: fn(&[f64]) -> f64, dim: usize) -> Self {
59 Self { function, dim }
60 }
61
62 pub fn evaluate(&self, coords: &[f64]) -> f64 {
76 (self.function)(coords)
77 }
78
79 pub fn dimension(&self) -> usize {
81 self.dim
82 }
83
84 pub fn partial_derivative(&self, coords: &[f64], axis: usize, h: f64) -> CalculusResult<f64> {
94 if axis >= self.dim {
95 return Err(CalculusError::InvalidDimension {
96 expected: self.dim,
97 got: axis,
98 });
99 }
100
101 let mut coords_plus = coords.to_vec();
102 let mut coords_minus = coords.to_vec();
103
104 coords_plus[axis] += h;
105 coords_minus[axis] -= h;
106
107 let f_plus = self.evaluate(&coords_plus);
108 let f_minus = self.evaluate(&coords_minus);
109
110 Ok((f_plus - f_minus) / (2.0 * h))
111 }
112
113 }
117
118impl<const P: usize, const Q: usize, const R: usize> std::fmt::Debug for ScalarField<P, Q, R> {
119 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
120 f.debug_struct("ScalarField")
121 .field("dim", &self.dim)
122 .field("signature", &format!("Cl({},{},{})", P, Q, R))
123 .finish()
124 }
125}
126
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_scalar_field_evaluation() {
        let f = ScalarField::<3, 0, 0>::new(|coords| coords[0].powi(2) + coords[1].powi(2));

        assert!((f.evaluate(&[0.0, 0.0, 0.0]) - 0.0).abs() < 1e-10);
        assert!((f.evaluate(&[3.0, 4.0, 0.0]) - 25.0).abs() < 1e-10);
        assert!((f.evaluate(&[1.0, 1.0, 0.0]) - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_partial_derivative() {
        let f = ScalarField::<3, 0, 0>::new(|coords| coords[0].powi(2) + coords[1].powi(2));

        let df_dx = f.partial_derivative(&[3.0, 4.0, 0.0], 0, 1e-5).unwrap();
        let df_dy = f.partial_derivative(&[3.0, 4.0, 0.0], 1, 1e-5).unwrap();

        assert!(
            (df_dx - 6.0).abs() < 1e-6,
            "∂f/∂x should be 6.0, got {}",
            df_dx
        );
        assert!(
            (df_dy - 8.0).abs() < 1e-6,
            "∂f/∂y should be 8.0, got {}",
            df_dy
        );
    }

    #[test]
    fn test_partial_derivative_rejects_out_of_range_axis() {
        let f = ScalarField::<3, 0, 0>::new(|coords| coords[0]);

        let result = f.partial_derivative(&[1.0, 2.0, 3.0], 3, 1e-5);
        assert!(matches!(
            result,
            Err(CalculusError::InvalidDimension {
                expected: 3,
                got: 3
            })
        ));
    }

    #[test]
    fn test_dimension() {
        // `new` derives the dimension from the signature; `with_dimension` overrides it.
        assert_eq!(ScalarField::<2, 1, 1>::new(|coords| coords[0]).dimension(), 4);
        assert_eq!(
            ScalarField::<3, 0, 0>::with_dimension(|coords| coords[0], 2).dimension(),
            2
        );
    }

    #[test]
    fn test_debug_reports_dimension_and_signature() {
        let f = ScalarField::<3, 1, 0>::new(|coords| coords[0]);
        let rendered = format!("{:?}", f);

        assert!(rendered.contains("dim: 4"), "unexpected Debug output: {}", rendered);
        assert!(rendered.contains("Cl(3,1,0)"), "unexpected Debug output: {}", rendered);
    }

    #[test]
    fn test_scalar_field_combination() {
        // Fields wrap plain `fn` pointers, which cannot capture `f` and `g`. Each combination
        // is therefore written out by hand and checked pointwise against its components.
        let f = ScalarField::<3, 0, 0>::new(|coords| coords[0] + coords[1]);
        let g = ScalarField::<3, 0, 0>::new(|coords| coords[0] * coords[1]);
        let p = [2.0, 3.0, 0.0];

        let sum =
            ScalarField::<3, 0, 0>::new(|coords| (coords[0] + coords[1]) + (coords[0] * coords[1]));
        assert!((sum.evaluate(&p) - 11.0).abs() < 1e-10);
        assert!((sum.evaluate(&p) - (f.evaluate(&p) + g.evaluate(&p))).abs() < 1e-10);

        let product =
            ScalarField::<3, 0, 0>::new(|coords| (coords[0] + coords[1]) * (coords[0] * coords[1]));
        assert!((product.evaluate(&p) - 30.0).abs() < 1e-10);
        assert!((product.evaluate(&p) - f.evaluate(&p) * g.evaluate(&p)).abs() < 1e-10);

        let scaled = ScalarField::<3, 0, 0>::new(|coords| 2.0 * (coords[0] + coords[1]));
        assert!((scaled.evaluate(&p) - 10.0).abs() < 1e-10);
        assert!((scaled.evaluate(&p) - 2.0 * f.evaluate(&p)).abs() < 1e-10);
    }
}