amari_calculus/fields/
scalar_field.rs

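//! Scalar field implementation: functions `f: ℝⁿ → ℝ` over n-dimensional
//! coordinates, with centered finite-difference partial derivatives.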
2
3use crate::{CalculusError, CalculusResult};
4
/// Scalar field `f: ℝⁿ → ℝ`, i.e. a rule assigning one `f64` to every point of an
/// n-dimensional domain.
///
/// The const parameters `P`, `Q` and `R` give the `Cl(P, Q, R)` signature of the
/// associated algebra. Fields built with [`ScalarField::new`] use `P + Q + R` as
/// the domain dimension; [`ScalarField::with_dimension`] sets it explicitly.
///
/// # Examples
///
/// ```
/// use amari_calculus::ScalarField;
///
/// // f(x, y) = x² + y², embedded in a three-dimensional domain
/// let f = ScalarField::<3, 0, 0>::new(|coords| coords[0].powi(2) + coords[1].powi(2));
///
/// // 3² + 4² = 25
/// let value = f.evaluate(&[3.0, 4.0, 0.0]);
/// assert!((value - 25.0).abs() < 1e-10);
/// ```
#[derive(Clone)]
pub struct ScalarField<const P: usize, const Q: usize, const R: usize> {
    /// Plain function pointer that evaluates the field at a coordinate slice
    /// (only non-capturing closures coerce to this type).
    function: fn(&[f64]) -> f64,
    /// Number of coordinates in the domain.
    dim: usize,
}
30
31impl<const P: usize, const Q: usize, const R: usize> ScalarField<P, Q, R> {
32    /// Create a new scalar field from a function
33    ///
34    /// # Arguments
35    ///
36    /// * `function` - Function mapping coordinates to scalar values
37    ///
38    /// # Examples
39    ///
40    /// ```
41    /// use amari_calculus::ScalarField;
42    ///
43    /// // Quadratic function f(x, y) = x² + y²
44    /// let f = ScalarField::<3, 0, 0>::new(|coords| {
45    ///     coords[0].powi(2) + coords[1].powi(2)
46    /// });
47    /// ```
48    pub fn new(function: fn(&[f64]) -> f64) -> Self {
49        Self {
50            function,
51            dim: P + Q + R,
52        }
53    }
54
55    /// Create a scalar field with explicit dimension
56    ///
57    /// Useful when the field dimension doesn't match the algebra dimension.
58    pub fn with_dimension(function: fn(&[f64]) -> f64, dim: usize) -> Self {
59        Self { function, dim }
60    }
61
62    /// Evaluate the scalar field at a point
63    ///
64    /// # Arguments
65    ///
66    /// * `coords` - Coordinates of the point
67    ///
68    /// # Returns
69    ///
70    /// The scalar value at the point
71    ///
72    /// # Errors
73    ///
74    /// Returns error if coordinate dimension doesn't match field dimension
75    pub fn evaluate(&self, coords: &[f64]) -> f64 {
76        (self.function)(coords)
77    }
78
79    /// Get the domain dimension
80    pub fn dimension(&self) -> usize {
81        self.dim
82    }
83
84    /// Compute numerical derivative along coordinate axis
85    ///
86    /// Uses centered difference: f'(x) ≈ (f(x+h) - f(x-h)) / (2h)
87    ///
88    /// # Arguments
89    ///
90    /// * `coords` - Point at which to compute derivative
91    /// * `axis` - Coordinate axis index (0 = x, 1 = y, etc.)
92    /// * `h` - Step size (default: 1e-5)
93    pub fn partial_derivative(&self, coords: &[f64], axis: usize, h: f64) -> CalculusResult<f64> {
94        if axis >= self.dim {
95            return Err(CalculusError::InvalidDimension {
96                expected: self.dim,
97                got: axis,
98            });
99        }
100
101        let mut coords_plus = coords.to_vec();
102        let mut coords_minus = coords.to_vec();
103
104        coords_plus[axis] += h;
105        coords_minus[axis] -= h;
106
107        let f_plus = self.evaluate(&coords_plus);
108        let f_minus = self.evaluate(&coords_minus);
109
110        Ok((f_plus - f_minus) / (2.0 * h))
111    }
112
113    // Note: Methods like add(), mul(), and scale() that combine fields
114    // are not implemented because function pointers cannot capture variables.
115    // Users should create new ScalarField instances manually when combining fields.
116}
117
118impl<const P: usize, const Q: usize, const R: usize> std::fmt::Debug for ScalarField<P, Q, R> {
119    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
120        f.debug_struct("ScalarField")
121            .field("dim", &self.dim)
122            .field("signature", &format!("Cl({},{},{})", P, Q, R))
123            .finish()
124    }
125}
126
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_scalar_field_evaluation() {
        // f(x, y) = x² + y²
        let f = ScalarField::<3, 0, 0>::new(|coords| coords[0].powi(2) + coords[1].powi(2));

        assert!((f.evaluate(&[0.0, 0.0, 0.0]) - 0.0).abs() < 1e-10);
        assert!((f.evaluate(&[3.0, 4.0, 0.0]) - 25.0).abs() < 1e-10);
        assert!((f.evaluate(&[1.0, 1.0, 0.0]) - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_dimension() {
        // `new` derives the dimension from the signature P + Q + R.
        let f = ScalarField::<2, 1, 0>::new(|coords| coords[0]);
        assert_eq!(f.dimension(), 3);

        // `with_dimension` overrides it.
        let g = ScalarField::<3, 0, 0>::with_dimension(|coords| coords[0], 2);
        assert_eq!(g.dimension(), 2);
    }

    #[test]
    fn test_partial_derivative() {
        // f(x, y) = x² + y²
        // ∂f/∂x = 2x, ∂f/∂y = 2y
        let f = ScalarField::<3, 0, 0>::new(|coords| coords[0].powi(2) + coords[1].powi(2));

        let df_dx = f.partial_derivative(&[3.0, 4.0, 0.0], 0, 1e-5).unwrap();
        let df_dy = f.partial_derivative(&[3.0, 4.0, 0.0], 1, 1e-5).unwrap();

        assert!(
            (df_dx - 6.0).abs() < 1e-6,
            "∂f/∂x should be 6.0, got {}",
            df_dx
        );
        assert!(
            (df_dy - 8.0).abs() < 1e-6,
            "∂f/∂y should be 8.0, got {}",
            df_dy
        );
    }

    #[test]
    fn test_partial_derivative_rejects_out_of_range_axis() {
        let f = ScalarField::<3, 0, 0>::new(|coords| coords[0]);
        assert!(f.partial_derivative(&[0.0, 0.0, 0.0], 3, 1e-5).is_err());
    }

    #[test]
    fn test_scalar_field_combination() {
        // Fields cannot be combined automatically (function pointers cannot
        // capture), so combinations are written by hand and checked against
        // the component fields.
        let f = ScalarField::<3, 0, 0>::new(|coords| coords[0] + coords[1]);
        let g = ScalarField::<3, 0, 0>::new(|coords| coords[0] * coords[1]);
        let p = [2.0, 3.0, 0.0];

        // h = f + g
        let h =
            ScalarField::<3, 0, 0>::new(|coords| (coords[0] + coords[1]) + (coords[0] * coords[1]));
        assert!((h.evaluate(&p) - 11.0).abs() < 1e-10); // (2+3) + (2*3) = 11
        assert!((h.evaluate(&p) - (f.evaluate(&p) + g.evaluate(&p))).abs() < 1e-10);

        // h = f * g
        let h =
            ScalarField::<3, 0, 0>::new(|coords| (coords[0] + coords[1]) * (coords[0] * coords[1]));
        assert!((h.evaluate(&p) - 30.0).abs() < 1e-10); // (2+3) * (2*3) = 30
        assert!((h.evaluate(&p) - f.evaluate(&p) * g.evaluate(&p)).abs() < 1e-10);

        // h = 2 * f
        let h = ScalarField::<3, 0, 0>::new(|coords| 2.0 * (coords[0] + coords[1]));
        assert!((h.evaluate(&p) - 10.0).abs() < 1e-10); // 2 * (2+3) = 10
        assert!((h.evaluate(&p) - 2.0 * f.evaluate(&p)).abs() < 1e-10);
    }
}
182}