1use crate::error::{Error, Result};
2use crate::observation::Observation;
3use crate::terms::Term;
4use nalgebra::{DMatrix, DVector};
5
/// Outcome of a pointing-model fit produced by [`fit_model`].
///
/// Every per-term vector is indexed in the same order as the `terms` slice
/// that was passed to the fit.
#[derive(Debug, Clone, PartialEq)]
pub struct FitResult {
    /// Coefficient value for each term.
    pub coefficients: Vec<f64>,
    /// Estimated standard error of each coefficient. Terms that were held
    /// fixed report `0.0`; `NaN` is reported for the free terms when the
    /// normal matrix could not be inverted.
    pub sigma: Vec<f64>,
    /// RMS on-sky residual of the fit, in arcseconds.
    pub sky_rms: f64,
    /// `Term::name()` of each term.
    pub term_names: Vec<String>,
}
13
14pub fn fit_model(
15 observations: &[&Observation],
16 terms: &[Box<dyn Term>],
17 fixed: &[bool],
18 coefficients: &[f64],
19 latitude: f64,
20) -> Result<FitResult> {
21 let free_count = fixed.iter().filter(|&&f| !f).count();
22 if free_count == 0 && !terms.is_empty() {
23 return Err(Error::Fit("all terms are fixed".into()));
24 }
25 if terms.is_empty() {
26 return Err(Error::Fit("no terms to fit".into()));
27 }
28 if observations.len() < free_count {
29 return Err(Error::Fit("insufficient observations".into()));
30 }
31
32 let free_indices: Vec<usize> = fixed
33 .iter()
34 .enumerate()
35 .filter(|(_, &f)| !f)
36 .map(|(i, _)| i)
37 .collect();
38 let fixed_indices: Vec<usize> = fixed
39 .iter()
40 .enumerate()
41 .filter(|(_, &f)| f)
42 .map(|(i, _)| i)
43 .collect();
44
45 let mut b = build_residuals(observations);
46 let w = build_weights(observations);
47 let a_full = build_design_matrix(observations, terms, latitude);
48
49 subtract_fixed_contributions(&mut b, &a_full, coefficients, &fixed_indices);
50
51 let a_free = extract_columns(&a_full, &free_indices);
52 let free_coeffs = solve_weighted(&a_free, &b, &w)?;
53 let free_residuals = &b - &a_free * &free_coeffs;
54
55 let mut all_coeffs = vec![0.0; terms.len()];
56 for (fi, &idx) in free_indices.iter().enumerate() {
57 all_coeffs[idx] = free_coeffs[fi];
58 }
59
60 let full_residuals = build_residuals(observations);
61 let full_a = &a_full;
62 let all_coeffs_dv = DVector::from_vec(all_coeffs.clone());
63 let actual_residuals = &full_residuals - full_a * &all_coeffs_dv;
64
65 let sigma = compute_sigma_free(&a_free, &free_residuals, &w, &free_indices, terms.len());
66 let sky_rms = compute_sky_rms(&actual_residuals, observations);
67 let term_names = terms.iter().map(|t| t.name().to_string()).collect();
68
69 Ok(FitResult {
70 coefficients: all_coeffs,
71 sigma,
72 sky_rms,
73 term_names,
74 })
75}
76
77fn subtract_fixed_contributions(
78 b: &mut DVector<f64>,
79 a: &DMatrix<f64>,
80 coefficients: &[f64],
81 fixed_indices: &[usize],
82) {
83 for &idx in fixed_indices {
84 let coeff = coefficients[idx];
85 for row in 0..a.nrows() {
86 b[row] -= a[(row, idx)] * coeff;
87 }
88 }
89}
90
91fn extract_columns(a: &DMatrix<f64>, cols: &[usize]) -> DMatrix<f64> {
92 let rows = a.nrows();
93 let m = cols.len();
94 let mut out = DMatrix::zeros(rows, m);
95 for (j, &col) in cols.iter().enumerate() {
96 for i in 0..rows {
97 out[(i, j)] = a[(i, col)];
98 }
99 }
100 out
101}
102
103pub fn build_residuals(observations: &[&Observation]) -> DVector<f64> {
104 let n = observations.len();
105 let mut b = DVector::zeros(2 * n);
106 for (i, obs) in observations.iter().enumerate() {
107 b[2 * i] = (obs.actual_ha - obs.commanded_ha).arcseconds();
108 b[2 * i + 1] = (obs.observed_dec - obs.catalog_dec).arcseconds();
109 }
110 b
111}
112
113fn build_weights(observations: &[&Observation]) -> DVector<f64> {
114 let n = observations.len();
115 let mut w = DVector::zeros(2 * n);
116 for (i, obs) in observations.iter().enumerate() {
117 let cos_dec = libm::cos(obs.catalog_dec.radians());
118 w[2 * i] = cos_dec * cos_dec;
119 w[2 * i + 1] = 1.0;
120 }
121 w
122}
123
124fn build_design_matrix(
125 observations: &[&Observation],
126 terms: &[Box<dyn Term>],
127 lat: f64,
128) -> DMatrix<f64> {
129 let n = observations.len();
130 let m = terms.len();
131 let mut a = DMatrix::zeros(2 * n, m);
132 for (i, obs) in observations.iter().enumerate() {
133 let h = obs.commanded_ha.radians();
134 let dec = obs.catalog_dec.radians();
135 let pier = obs.pier_side.sign();
136 for (j, term) in terms.iter().enumerate() {
137 let (jh, jd) = term.jacobian_equatorial(h, dec, lat, pier);
138 a[(2 * i, j)] = jh;
139 a[(2 * i + 1, j)] = jd;
140 }
141 }
142 a
143}
144
145fn solve_weighted(a: &DMatrix<f64>, b: &DVector<f64>, w: &DVector<f64>) -> Result<DVector<f64>> {
146 let sqrt_w = w.map(libm::sqrt);
147 let rows = a.nrows();
148 let cols = a.ncols();
149 let a_w = DMatrix::from_fn(rows, cols, |i, j| a[(i, j)] * sqrt_w[i]);
150 let b_w = DVector::from_fn(rows, |i, _| b[i] * sqrt_w[i]);
151 let svd = a_w.svd(true, true);
152 svd.solve(&b_w, 1e-10)
153 .map_err(|e| Error::Fit(format!("SVD solve failed: {}", e)))
154}
155
156fn compute_sigma_free(
157 a_free: &DMatrix<f64>,
158 residuals: &DVector<f64>,
159 w: &DVector<f64>,
160 free_indices: &[usize],
161 total_terms: usize,
162) -> Vec<f64> {
163 let n = a_free.nrows();
164 let m = a_free.ncols();
165 let dof = n.saturating_sub(m).max(1);
166 let sqrt_w = w.map(libm::sqrt);
167 let a_w = DMatrix::from_fn(n, m, |i, j| a_free[(i, j)] * sqrt_w[i]);
168 let r_w = DVector::from_fn(n, |i, _| residuals[i] * sqrt_w[i]);
169 let s2 = r_w.dot(&r_w) / dof as f64;
170 let ata = a_w.transpose() * &a_w;
171 let free_sigma = match ata.try_inverse() {
172 Some(inv) => (0..m)
173 .map(|j| libm::sqrt((s2 * inv[(j, j)]).abs()))
174 .collect::<Vec<_>>(),
175 None => vec![f64::NAN; m],
176 };
177 let mut sigma = vec![0.0; total_terms];
178 for (fi, &idx) in free_indices.iter().enumerate() {
179 sigma[idx] = free_sigma[fi];
180 }
181 sigma
182}
183
184pub fn compute_sky_rms(residuals: &DVector<f64>, observations: &[&Observation]) -> f64 {
185 let n = observations.len();
186 if n == 0 {
187 return 0.0;
188 }
189 let mut sum_sq = 0.0;
190 for i in 0..n {
191 let dh = residuals[2 * i];
192 let dd = residuals[2 * i + 1];
193 let cos_dec = libm::cos(observations[i].catalog_dec.radians());
194 let dx = dh * cos_dec;
195 sum_sq += dx * dx + dd * dd;
196 }
197 libm::sqrt(sum_sq / n as f64)
198}
199
#[cfg(test)]
mod tests {
    use super::*;
    use crate::observation::PierSide;
    use crate::terms::create_term;
    use celestial_core::Angle;

    /// East-pier observation at `dec_deg` whose only error is in hour angle:
    /// commanded at `cmd_ha_arcsec`, actually reached at `act_ha_arcsec`.
    fn make_obs(cmd_ha_arcsec: f64, act_ha_arcsec: f64, dec_deg: f64) -> Observation {
        let dec = Angle::from_degrees(dec_deg);
        let zero_hours = Angle::from_hours(0.0);
        Observation {
            catalog_ra: zero_hours,
            catalog_dec: dec,
            observed_ra: zero_hours,
            observed_dec: dec,
            lst: zero_hours,
            commanded_ha: Angle::from_arcseconds(cmd_ha_arcsec),
            actual_ha: Angle::from_arcseconds(act_ha_arcsec),
            pier_side: PierSide::East,
            masked: false,
        }
    }

    /// Borrows every observation, matching the slice-of-references API.
    fn borrow_all(obs: &[Observation]) -> Vec<&Observation> {
        obs.iter().collect()
    }

    #[test]
    fn fit_ih_recovers_known_coefficient() {
        // A constant +100" HA error at three declinations is recovered by IH as -100".
        let obs = [30.0, 45.0, 60.0].map(|dec| make_obs(0.0, 100.0, dec));
        let terms: Vec<Box<dyn Term>> = vec![create_term("IH").unwrap()];
        let result = fit_model(&borrow_all(&obs), &terms, &[false], &[0.0], 0.7).unwrap();
        assert_eq!(result.term_names, vec!["IH"]);
        assert!((result.coefficients[0] + 100.0).abs() < 1e-6);
    }

    #[test]
    fn fit_insufficient_observations() {
        // A single observation is not enough for two free terms.
        let obs = [make_obs(0.0, 100.0, 30.0)];
        let terms: Vec<Box<dyn Term>> =
            vec![create_term("IH").unwrap(), create_term("ID").unwrap()];
        let outcome = fit_model(&borrow_all(&obs), &terms, &[false, false], &[0.0, 0.0], 0.7);
        assert!(outcome.is_err());
    }

    #[test]
    fn fit_no_terms() {
        let obs = [make_obs(0.0, 100.0, 30.0)];
        let terms: Vec<Box<dyn Term>> = Vec::new();
        let outcome = fit_model(&borrow_all(&obs), &terms, &[], &[], 0.7);
        assert!(outcome.is_err());
    }

    #[test]
    fn sky_rms_known_residuals() {
        // At dec 0 the HA residual is unscaled, so each (3, 4) pair is a 5" offset.
        let obs = [make_obs(0.0, 0.0, 0.0), make_obs(0.0, 0.0, 0.0)];
        let residuals = DVector::from_vec(vec![3.0, 4.0, 3.0, 4.0]);
        assert_eq!(compute_sky_rms(&residuals, &borrow_all(&obs)), 5.0);
    }
}