Skip to main content

celestial_pointing/
solver.rs

1use crate::error::{Error, Result};
2use crate::observation::Observation;
3use crate::terms::Term;
4use nalgebra::{DMatrix, DVector};
5
/// Outcome of a pointing-model fit produced by [`fit_model`].
///
/// Every per-term vector is indexed in the same order as the `terms` slice that was
/// passed to the fit, so `coefficients[i]`, `sigma[i]` and `term_names[i]` all describe
/// the same term.
#[derive(Clone, Debug)]
pub struct FitResult {
    /// Coefficient of each term.
    pub coefficients: Vec<f64>,
    /// One-sigma uncertainty of each coefficient. Terms that were not free in the fit
    /// report `0.0`; free terms report `NaN` when the normal matrix could not be inverted.
    pub sigma: Vec<f64>,
    /// Root-mean-square on-sky residual per observation, in arcseconds (hour-angle
    /// residuals are scaled by cos(dec) before being combined with declination residuals).
    pub sky_rms: f64,
    /// Name of each term, as reported by the term itself.
    pub term_names: Vec<String>,
}
13
14pub fn fit_model(
15    observations: &[&Observation],
16    terms: &[Box<dyn Term>],
17    fixed: &[bool],
18    coefficients: &[f64],
19    latitude: f64,
20) -> Result<FitResult> {
21    let free_count = fixed.iter().filter(|&&f| !f).count();
22    if free_count == 0 && !terms.is_empty() {
23        return Err(Error::Fit("all terms are fixed".into()));
24    }
25    if terms.is_empty() {
26        return Err(Error::Fit("no terms to fit".into()));
27    }
28    if observations.len() < free_count {
29        return Err(Error::Fit("insufficient observations".into()));
30    }
31
32    let free_indices: Vec<usize> = fixed
33        .iter()
34        .enumerate()
35        .filter(|(_, &f)| !f)
36        .map(|(i, _)| i)
37        .collect();
38    let fixed_indices: Vec<usize> = fixed
39        .iter()
40        .enumerate()
41        .filter(|(_, &f)| f)
42        .map(|(i, _)| i)
43        .collect();
44
45    let mut b = build_residuals(observations);
46    let w = build_weights(observations);
47    let a_full = build_design_matrix(observations, terms, latitude);
48
49    subtract_fixed_contributions(&mut b, &a_full, coefficients, &fixed_indices);
50
51    let a_free = extract_columns(&a_full, &free_indices);
52    let free_coeffs = solve_weighted(&a_free, &b, &w)?;
53    let free_residuals = &b - &a_free * &free_coeffs;
54
55    let mut all_coeffs = vec![0.0; terms.len()];
56    for (fi, &idx) in free_indices.iter().enumerate() {
57        all_coeffs[idx] = free_coeffs[fi];
58    }
59
60    let full_residuals = build_residuals(observations);
61    let full_a = &a_full;
62    let all_coeffs_dv = DVector::from_vec(all_coeffs.clone());
63    let actual_residuals = &full_residuals - full_a * &all_coeffs_dv;
64
65    let sigma = compute_sigma_free(&a_free, &free_residuals, &w, &free_indices, terms.len());
66    let sky_rms = compute_sky_rms(&actual_residuals, observations);
67    let term_names = terms.iter().map(|t| t.name().to_string()).collect();
68
69    Ok(FitResult {
70        coefficients: all_coeffs,
71        sigma,
72        sky_rms,
73        term_names,
74    })
75}
76
77fn subtract_fixed_contributions(
78    b: &mut DVector<f64>,
79    a: &DMatrix<f64>,
80    coefficients: &[f64],
81    fixed_indices: &[usize],
82) {
83    for &idx in fixed_indices {
84        let coeff = coefficients[idx];
85        for row in 0..a.nrows() {
86            b[row] -= a[(row, idx)] * coeff;
87        }
88    }
89}
90
91fn extract_columns(a: &DMatrix<f64>, cols: &[usize]) -> DMatrix<f64> {
92    let rows = a.nrows();
93    let m = cols.len();
94    let mut out = DMatrix::zeros(rows, m);
95    for (j, &col) in cols.iter().enumerate() {
96        for i in 0..rows {
97            out[(i, j)] = a[(i, col)];
98        }
99    }
100    out
101}
102
103pub fn build_residuals(observations: &[&Observation]) -> DVector<f64> {
104    let n = observations.len();
105    let mut b = DVector::zeros(2 * n);
106    for (i, obs) in observations.iter().enumerate() {
107        b[2 * i] = (obs.actual_ha - obs.commanded_ha).arcseconds();
108        b[2 * i + 1] = (obs.observed_dec - obs.catalog_dec).arcseconds();
109    }
110    b
111}
112
113fn build_weights(observations: &[&Observation]) -> DVector<f64> {
114    let n = observations.len();
115    let mut w = DVector::zeros(2 * n);
116    for (i, obs) in observations.iter().enumerate() {
117        let cos_dec = libm::cos(obs.catalog_dec.radians());
118        w[2 * i] = cos_dec * cos_dec;
119        w[2 * i + 1] = 1.0;
120    }
121    w
122}
123
124fn build_design_matrix(
125    observations: &[&Observation],
126    terms: &[Box<dyn Term>],
127    lat: f64,
128) -> DMatrix<f64> {
129    let n = observations.len();
130    let m = terms.len();
131    let mut a = DMatrix::zeros(2 * n, m);
132    for (i, obs) in observations.iter().enumerate() {
133        let h = obs.commanded_ha.radians();
134        let dec = obs.catalog_dec.radians();
135        let pier = obs.pier_side.sign();
136        for (j, term) in terms.iter().enumerate() {
137            let (jh, jd) = term.jacobian_equatorial(h, dec, lat, pier);
138            a[(2 * i, j)] = jh;
139            a[(2 * i + 1, j)] = jd;
140        }
141    }
142    a
143}
144
145fn solve_weighted(a: &DMatrix<f64>, b: &DVector<f64>, w: &DVector<f64>) -> Result<DVector<f64>> {
146    let sqrt_w = w.map(libm::sqrt);
147    let rows = a.nrows();
148    let cols = a.ncols();
149    let a_w = DMatrix::from_fn(rows, cols, |i, j| a[(i, j)] * sqrt_w[i]);
150    let b_w = DVector::from_fn(rows, |i, _| b[i] * sqrt_w[i]);
151    let svd = a_w.svd(true, true);
152    svd.solve(&b_w, 1e-10)
153        .map_err(|e| Error::Fit(format!("SVD solve failed: {}", e)))
154}
155
156fn compute_sigma_free(
157    a_free: &DMatrix<f64>,
158    residuals: &DVector<f64>,
159    w: &DVector<f64>,
160    free_indices: &[usize],
161    total_terms: usize,
162) -> Vec<f64> {
163    let n = a_free.nrows();
164    let m = a_free.ncols();
165    let dof = n.saturating_sub(m).max(1);
166    let sqrt_w = w.map(libm::sqrt);
167    let a_w = DMatrix::from_fn(n, m, |i, j| a_free[(i, j)] * sqrt_w[i]);
168    let r_w = DVector::from_fn(n, |i, _| residuals[i] * sqrt_w[i]);
169    let s2 = r_w.dot(&r_w) / dof as f64;
170    let ata = a_w.transpose() * &a_w;
171    let free_sigma = match ata.try_inverse() {
172        Some(inv) => (0..m)
173            .map(|j| libm::sqrt((s2 * inv[(j, j)]).abs()))
174            .collect::<Vec<_>>(),
175        None => vec![f64::NAN; m],
176    };
177    let mut sigma = vec![0.0; total_terms];
178    for (fi, &idx) in free_indices.iter().enumerate() {
179        sigma[idx] = free_sigma[fi];
180    }
181    sigma
182}
183
184pub fn compute_sky_rms(residuals: &DVector<f64>, observations: &[&Observation]) -> f64 {
185    let n = observations.len();
186    if n == 0 {
187        return 0.0;
188    }
189    let mut sum_sq = 0.0;
190    for i in 0..n {
191        let dh = residuals[2 * i];
192        let dd = residuals[2 * i + 1];
193        let cos_dec = libm::cos(observations[i].catalog_dec.radians());
194        let dx = dh * cos_dec;
195        sum_sq += dx * dx + dd * dd;
196    }
197    libm::sqrt(sum_sq / n as f64)
198}
199
#[cfg(test)]
mod tests {
    use super::*;
    use crate::observation::PierSide;
    use crate::terms::create_term;
    use celestial_core::Angle;

    /// Build an east-pier observation whose only pointing error is in hour angle.
    fn make_obs(cmd_ha_arcsec: f64, act_ha_arcsec: f64, dec_deg: f64) -> Observation {
        Observation {
            catalog_ra: Angle::from_hours(0.0),
            catalog_dec: Angle::from_degrees(dec_deg),
            observed_ra: Angle::from_hours(0.0),
            observed_dec: Angle::from_degrees(dec_deg),
            lst: Angle::from_hours(0.0),
            commanded_ha: Angle::from_arcseconds(cmd_ha_arcsec),
            actual_ha: Angle::from_arcseconds(act_ha_arcsec),
            pier_side: PierSide::East,
            masked: false,
        }
    }

    #[test]
    fn fit_ih_recovers_known_coefficient() {
        let obs = [
            make_obs(0.0, 100.0, 30.0),
            make_obs(0.0, 100.0, 45.0),
            make_obs(0.0, 100.0, 60.0),
        ];
        let refs: Vec<&Observation> = obs.iter().collect();
        let terms: Vec<Box<dyn Term>> = vec![create_term("IH").unwrap()];
        let fixed = [false];
        let coeffs = [0.0];
        let result = fit_model(&refs, &terms, &fixed, &coeffs, 0.7).unwrap();
        assert_eq!(result.term_names, vec!["IH"]);
        assert!((result.coefficients[0] - (-100.0)).abs() < 1e-6);
    }

    #[test]
    fn fit_insufficient_observations() {
        let obs = [make_obs(0.0, 100.0, 30.0)];
        let refs: Vec<&Observation> = obs.iter().collect();
        let terms: Vec<Box<dyn Term>> =
            vec![create_term("IH").unwrap(), create_term("ID").unwrap()];
        let fixed = [false, false];
        let coeffs = [0.0, 0.0];
        let result = fit_model(&refs, &terms, &fixed, &coeffs, 0.7);
        assert!(result.is_err());
    }

    #[test]
    fn fit_no_terms() {
        let obs = [make_obs(0.0, 100.0, 30.0)];
        let refs: Vec<&Observation> = obs.iter().collect();
        let terms: Vec<Box<dyn Term>> = vec![];
        let fixed: [bool; 0] = [];
        let coeffs: [f64; 0] = [];
        let result = fit_model(&refs, &terms, &fixed, &coeffs, 0.7);
        assert!(result.is_err());
    }

    #[test]
    fn fit_all_terms_fixed() {
        let obs = [make_obs(0.0, 100.0, 30.0), make_obs(0.0, 100.0, 45.0)];
        let refs: Vec<&Observation> = obs.iter().collect();
        let terms: Vec<Box<dyn Term>> = vec![create_term("IH").unwrap()];
        let result = fit_model(&refs, &terms, &[true], &[0.0], 0.7);
        assert!(result.is_err());
    }

    #[test]
    fn build_residuals_interleaves_ha_and_dec() {
        let obs = [make_obs(10.0, 25.0, 30.0), make_obs(0.0, -5.0, 60.0)];
        let refs: Vec<&Observation> = obs.iter().collect();
        let b = build_residuals(&refs);
        assert_eq!(b.len(), 4);
        assert!((b[0] - 15.0).abs() < 1e-9);
        assert!(b[1].abs() < 1e-9);
        assert!((b[2] - (-5.0)).abs() < 1e-9);
        assert!(b[3].abs() < 1e-9);
    }

    #[test]
    fn sky_rms_known_residuals() {
        let obs = [make_obs(0.0, 0.0, 0.0), make_obs(0.0, 0.0, 0.0)];
        let refs: Vec<&Observation> = obs.iter().collect();
        let n = obs.len();
        let mut residuals = DVector::zeros(2 * n);
        residuals[0] = 3.0;
        residuals[1] = 4.0;
        residuals[2] = 3.0;
        residuals[3] = 4.0;
        let rms = compute_sky_rms(&residuals, &refs);
        assert_eq!(rms, 5.0);
    }

    #[test]
    fn sky_rms_scales_hour_angle_by_cos_dec() {
        // At dec = 60° an hour-angle residual of 2" spans cos(60°) * 2" = 1" on the sky.
        let obs = [make_obs(0.0, 0.0, 60.0)];
        let refs: Vec<&Observation> = obs.iter().collect();
        let mut residuals = DVector::zeros(2);
        residuals[0] = 2.0;
        let rms = compute_sky_rms(&residuals, &refs);
        assert!((rms - 1.0).abs() < 1e-12);
    }
}