Skip to main content

celestial_pointing/
model.rs

1use crate::error::{Error, Result};
2use crate::observation::PierSide;
3use crate::terms::{create_term, Term};
4use celestial_core::Angle;
5
6#[derive(Default)]
7pub struct PointingModel {
8    terms: Vec<Box<dyn Term>>,
9    coefficients: Vec<f64>,
10    fixed: Vec<bool>,
11    parallel: Vec<bool>,
12}
13
14impl PointingModel {
15    pub fn new() -> Self {
16        Self::default()
17    }
18
19    pub fn add_term(&mut self, name: &str) -> Result<()> {
20        let term = create_term(name)?;
21        self.terms.push(term);
22        self.coefficients.push(0.0);
23        self.fixed.push(false);
24        self.parallel.push(true);
25        Ok(())
26    }
27
28    pub fn remove_term(&mut self, name: &str) {
29        if let Some(idx) = self.terms.iter().position(|t| t.name() == name) {
30            self.terms.remove(idx);
31            self.coefficients.remove(idx);
32            self.fixed.remove(idx);
33            self.parallel.remove(idx);
34        }
35    }
36
37    pub fn remove_all(&mut self) {
38        self.terms.clear();
39        self.coefficients.clear();
40        self.fixed.clear();
41        self.parallel.clear();
42    }
43
44    pub fn fix_term(&mut self, name: &str) -> bool {
45        if let Some(idx) = self.terms.iter().position(|t| t.name() == name) {
46            self.fixed[idx] = true;
47            return true;
48        }
49        false
50    }
51
52    pub fn fix_all(&mut self) {
53        self.fixed.iter_mut().for_each(|f| *f = true);
54    }
55
56    pub fn unfix_term(&mut self, name: &str) -> bool {
57        if let Some(idx) = self.terms.iter().position(|t| t.name() == name) {
58            self.fixed[idx] = false;
59            return true;
60        }
61        false
62    }
63
64    pub fn unfix_all(&mut self) {
65        self.fixed.iter_mut().for_each(|f| *f = false);
66    }
67
68    pub fn is_fixed(&self, idx: usize) -> bool {
69        self.fixed.get(idx).copied().unwrap_or(false)
70    }
71
72    pub fn fixed_flags(&self) -> &[bool] {
73        &self.fixed
74    }
75
76    pub fn set_parallel(&mut self, name: &str) -> bool {
77        if let Some(idx) = self.terms.iter().position(|t| t.name() == name) {
78            self.parallel[idx] = true;
79            return true;
80        }
81        false
82    }
83
84    pub fn set_chained(&mut self, name: &str) -> bool {
85        if let Some(idx) = self.terms.iter().position(|t| t.name() == name) {
86            self.parallel[idx] = false;
87            return true;
88        }
89        false
90    }
91
92    pub fn set_all_parallel(&mut self) {
93        self.parallel.iter_mut().for_each(|p| *p = true);
94    }
95
96    pub fn set_all_chained(&mut self) {
97        self.parallel.iter_mut().for_each(|p| *p = false);
98    }
99
100    pub fn is_parallel(&self, idx: usize) -> bool {
101        self.parallel.get(idx).copied().unwrap_or(true)
102    }
103
104    pub fn zero_coefficients(&mut self) {
105        self.coefficients.iter_mut().for_each(|c| *c = 0.0);
106    }
107
108    pub fn term_count(&self) -> usize {
109        self.terms.len()
110    }
111
112    pub fn term_names(&self) -> Vec<&str> {
113        self.terms.iter().map(|t| t.name()).collect()
114    }
115
116    pub fn terms(&self) -> &[Box<dyn Term>] {
117        &self.terms
118    }
119
120    pub fn coefficients(&self) -> &[f64] {
121        &self.coefficients
122    }
123
124    pub fn set_coefficients(&mut self, coeffs: &[f64]) -> Result<()> {
125        if coeffs.len() != self.terms.len() {
126            return Err(Error::Fit(format!(
127                "coefficient count {} does not match term count {}",
128                coeffs.len(),
129                self.terms.len()
130            )));
131        }
132        self.coefficients.copy_from_slice(coeffs);
133        Ok(())
134    }
135
136    pub fn apply_equatorial(&self, h: f64, dec: f64, lat: f64, pier: f64) -> (f64, f64) {
137        let mut dh = 0.0;
138        let mut ddec = 0.0;
139        for (term, &coeff) in self.terms.iter().zip(self.coefficients.iter()) {
140            let (jh, jd) = term.jacobian_equatorial(h, dec, lat, pier);
141            dh += coeff * jh;
142            ddec += coeff * jd;
143        }
144        (dh, ddec)
145    }
146
147    pub fn apply_altaz(&self, az: f64, el: f64, lat: f64) -> (f64, f64) {
148        let mut daz = 0.0;
149        let mut del = 0.0;
150        for (term, &coeff) in self.terms.iter().zip(self.coefficients.iter()) {
151            let (ja, je) = term.jacobian_altaz(az, el, lat);
152            daz += coeff * ja;
153            del += coeff * je;
154        }
155        (daz, del)
156    }
157
158    pub fn apply_equatorial_chained(&self, h: f64, dec: f64, lat: f64, pier: f64) -> (f64, f64) {
159        let mut h_corr = h;
160        let mut dec_corr = dec;
161
162        for (i, term) in self.terms.iter().enumerate() {
163            if !self.parallel[i] {
164                let (jh, jd) = term.jacobian_equatorial(h_corr, dec_corr, lat, pier);
165                h_corr += self.coefficients[i] * jh;
166                dec_corr += self.coefficients[i] * jd;
167            }
168        }
169
170        let mut dh = 0.0;
171        let mut ddec = 0.0;
172        for (i, term) in self.terms.iter().enumerate() {
173            if self.parallel[i] {
174                let (jh, jd) = term.jacobian_equatorial(h_corr, dec_corr, lat, pier);
175                dh += self.coefficients[i] * jh;
176                ddec += self.coefficients[i] * jd;
177            }
178        }
179
180        (h_corr + dh - h, dec_corr + ddec - dec)
181    }
182
183    pub fn target_to_command(
184        &self,
185        ra: Angle,
186        dec: Angle,
187        lst: Angle,
188        lat: Angle,
189        pier: PierSide,
190    ) -> (Angle, Angle) {
191        let ha = lst - ra;
192        let (dh, dd) =
193            self.apply_equatorial(ha.radians(), dec.radians(), lat.radians(), pier.sign());
194        let cmd_ha = ha - Angle::from_arcseconds(dh);
195        let cmd_dec = dec - Angle::from_arcseconds(dd);
196        let cmd_ra = lst - cmd_ha;
197        (cmd_ra, cmd_dec)
198    }
199
200    pub fn command_to_target(
201        &self,
202        ra_encoder: Angle,
203        dec_encoder: Angle,
204        lst: Angle,
205        lat: Angle,
206        pier: PierSide,
207    ) -> (Angle, Angle) {
208        let ha = lst - ra_encoder;
209        let (dh, dd) = self.apply_equatorial(
210            ha.radians(),
211            dec_encoder.radians(),
212            lat.radians(),
213            pier.sign(),
214        );
215        let true_ha = ha + Angle::from_arcseconds(dh);
216        let true_dec = dec_encoder + Angle::from_arcseconds(dd);
217        let true_ra = lst - true_ha;
218        (true_ra, true_dec)
219    }
220
221    pub fn predict_breakdown(
222        &self,
223        h: f64,
224        dec: f64,
225        lat: f64,
226        pier: f64,
227    ) -> Vec<(String, f64, f64)> {
228        self.terms
229            .iter()
230            .zip(self.coefficients.iter())
231            .map(|(term, &coeff)| {
232                let (jh, jd) = term.jacobian_equatorial(h, dec, lat, pier);
233                (term.name().to_string(), coeff * jh, coeff * jd)
234            })
235            .collect()
236    }
237}
238
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::FRAC_PI_4;

    #[test]
    fn apply_equatorial_ih_id() {
        let mut model = PointingModel::new();
        model.add_term("IH").unwrap();
        model.add_term("ID").unwrap();
        model.set_coefficients(&[10.0, 20.0]).unwrap();

        let (dh, ddec) = model.apply_equatorial(FRAC_PI_4, 0.5, 0.7, 1.0);
        assert_eq!(dh, -10.0);
        assert_eq!(ddec, -20.0);
    }

    #[test]
    fn apply_equatorial_id_west_pier() {
        let mut model = PointingModel::new();
        model.add_term("ID").unwrap();
        model.set_coefficients(&[20.0]).unwrap();

        // The ID sign flips with the pier side, so -1.0 gives +coefficient.
        let (_, ddec) = model.apply_equatorial(0.0, 0.0, 0.0, -1.0);
        assert_eq!(ddec, 20.0);
    }

    #[test]
    fn add_remove_terms() {
        let mut model = PointingModel::new();
        model.add_term("IH").unwrap();
        model.add_term("ID").unwrap();
        model.add_term("CH").unwrap();
        assert_eq!(model.term_count(), 3);
        assert_eq!(model.term_names(), vec!["IH", "ID", "CH"]);

        model.remove_term("ID");
        assert_eq!(model.term_count(), 2);
        assert_eq!(model.term_names(), vec!["IH", "CH"]);
        assert_eq!(model.coefficients().len(), 2);
    }

    #[test]
    fn remove_term_keeps_flags_aligned() {
        let mut model = PointingModel::new();
        model.add_term("IH").unwrap();
        model.add_term("ID").unwrap();
        model.add_term("CH").unwrap();
        model.set_coefficients(&[1.0, 2.0, 3.0]).unwrap();
        assert!(model.fix_term("CH"));
        assert!(model.set_chained("CH"));

        model.remove_term("ID");
        assert_eq!(model.coefficients(), &[1.0, 3.0]);
        assert_eq!(model.fixed_flags(), &[false, true]);
        assert!(model.is_parallel(0));
        assert!(!model.is_parallel(1));
    }

    #[test]
    fn remove_all_clears_model() {
        let mut model = PointingModel::new();
        model.add_term("IH").unwrap();
        model.add_term("ID").unwrap();
        model.remove_all();
        assert_eq!(model.term_count(), 0);
        assert_eq!(model.coefficients().len(), 0);
    }

    #[test]
    fn remove_nonexistent_is_noop() {
        let mut model = PointingModel::new();
        model.add_term("IH").unwrap();
        model.remove_term("ZZZZ");
        assert_eq!(model.term_count(), 1);
    }

    #[test]
    fn set_coefficients_wrong_length() {
        let mut model = PointingModel::new();
        model.add_term("IH").unwrap();
        let result = model.set_coefficients(&[1.0, 2.0]);
        assert!(result.is_err());
    }

    #[test]
    fn zero_coefficients_keeps_terms() {
        let mut model = PointingModel::new();
        model.add_term("IH").unwrap();
        model.add_term("ID").unwrap();
        model.set_coefficients(&[10.0, 20.0]).unwrap();
        model.zero_coefficients();
        assert_eq!(model.coefficients(), &[0.0, 0.0]);
        assert_eq!(model.term_count(), 2);
    }

    #[test]
    fn fix_and_unfix_flags() {
        let mut model = PointingModel::new();
        model.add_term("IH").unwrap();
        model.add_term("ID").unwrap();
        assert_eq!(model.fixed_flags(), &[false, false]);

        assert!(model.fix_term("ID"));
        assert!(!model.is_fixed(0));
        assert!(model.is_fixed(1));
        assert!(!model.fix_term("ZZZZ"));

        model.fix_all();
        assert_eq!(model.fixed_flags(), &[true, true]);
        assert!(model.unfix_term("IH"));
        assert_eq!(model.fixed_flags(), &[false, true]);
        assert!(!model.unfix_term("ZZZZ"));

        model.unfix_all();
        assert_eq!(model.fixed_flags(), &[false, false]);
        assert!(!model.is_fixed(99));
    }

    #[test]
    fn parallel_and_chained_flags() {
        let mut model = PointingModel::new();
        model.add_term("IH").unwrap();
        model.add_term("ID").unwrap();
        assert!(model.is_parallel(0));
        assert!(model.is_parallel(1));

        assert!(model.set_chained("IH"));
        assert!(!model.is_parallel(0));
        assert!(model.is_parallel(1));
        assert!(!model.set_chained("ZZZZ"));

        model.set_all_chained();
        assert!(!model.is_parallel(1));
        assert!(model.set_parallel("ID"));
        assert!(model.is_parallel(1));

        model.set_all_parallel();
        assert!(model.is_parallel(0));
        assert!(model.is_parallel(99));
    }

    #[test]
    fn breakdown_sums_to_total() {
        let mut model = PointingModel::new();
        model.add_term("IH").unwrap();
        model.add_term("ID").unwrap();
        model.add_term("CH").unwrap();
        model.set_coefficients(&[10.0, 20.0, 5.0]).unwrap();

        let (dh, ddec) = model.apply_equatorial(FRAC_PI_4, 0.5, 0.7, 1.0);
        let parts = model.predict_breakdown(FRAC_PI_4, 0.5, 0.7, 1.0);
        let names: Vec<&str> = parts.iter().map(|(name, _, _)| name.as_str()).collect();
        assert_eq!(names, vec!["IH", "ID", "CH"]);

        let sum_dh: f64 = parts.iter().map(|p| p.1).sum();
        let sum_ddec: f64 = parts.iter().map(|p| p.2).sum();
        assert!((sum_dh - dh).abs() < 1e-9);
        assert!((sum_ddec - ddec).abs() < 1e-9);
    }

    #[test]
    fn chained_with_all_parallel_matches_apply_equatorial() {
        let mut model = PointingModel::new();
        model.add_term("IH").unwrap();
        model.add_term("ID").unwrap();
        model.add_term("CH").unwrap();
        model.set_coefficients(&[10.0, 20.0, 5.0]).unwrap();

        let (dh, ddec) = model.apply_equatorial(FRAC_PI_4, 0.5, 0.7, 1.0);
        let (chain_dh, chain_ddec) = model.apply_equatorial_chained(FRAC_PI_4, 0.5, 0.7, 1.0);
        assert!((dh - chain_dh).abs() < 1e-9);
        assert!((ddec - chain_ddec).abs() < 1e-9);
    }

    #[test]
    fn empty_model_returns_zero_correction() {
        let model = PointingModel::new();
        let (dh, ddec) = model.apply_equatorial(1.0, 0.5, 0.7, 1.0);
        assert_eq!(dh, 0.0);
        assert_eq!(ddec, 0.0);
    }

    #[test]
    fn add_unknown_term_returns_error() {
        let mut model = PointingModel::new();
        let result = model.add_term("ZZZZ");
        assert!(result.is_err());
    }
}