1use crate::error::{Error, Result};
2use crate::observation::PierSide;
3use crate::terms::{create_term, Term};
4use celestial_core::Angle;
5
6#[derive(Default)]
7pub struct PointingModel {
8 terms: Vec<Box<dyn Term>>,
9 coefficients: Vec<f64>,
10 fixed: Vec<bool>,
11 parallel: Vec<bool>,
12}
13
14impl PointingModel {
15 pub fn new() -> Self {
16 Self::default()
17 }
18
19 pub fn add_term(&mut self, name: &str) -> Result<()> {
20 let term = create_term(name)?;
21 self.terms.push(term);
22 self.coefficients.push(0.0);
23 self.fixed.push(false);
24 self.parallel.push(true);
25 Ok(())
26 }
27
28 pub fn remove_term(&mut self, name: &str) {
29 if let Some(idx) = self.terms.iter().position(|t| t.name() == name) {
30 self.terms.remove(idx);
31 self.coefficients.remove(idx);
32 self.fixed.remove(idx);
33 self.parallel.remove(idx);
34 }
35 }
36
37 pub fn remove_all(&mut self) {
38 self.terms.clear();
39 self.coefficients.clear();
40 self.fixed.clear();
41 self.parallel.clear();
42 }
43
44 pub fn fix_term(&mut self, name: &str) -> bool {
45 if let Some(idx) = self.terms.iter().position(|t| t.name() == name) {
46 self.fixed[idx] = true;
47 return true;
48 }
49 false
50 }
51
52 pub fn fix_all(&mut self) {
53 self.fixed.iter_mut().for_each(|f| *f = true);
54 }
55
56 pub fn unfix_term(&mut self, name: &str) -> bool {
57 if let Some(idx) = self.terms.iter().position(|t| t.name() == name) {
58 self.fixed[idx] = false;
59 return true;
60 }
61 false
62 }
63
64 pub fn unfix_all(&mut self) {
65 self.fixed.iter_mut().for_each(|f| *f = false);
66 }
67
68 pub fn is_fixed(&self, idx: usize) -> bool {
69 self.fixed.get(idx).copied().unwrap_or(false)
70 }
71
72 pub fn fixed_flags(&self) -> &[bool] {
73 &self.fixed
74 }
75
76 pub fn set_parallel(&mut self, name: &str) -> bool {
77 if let Some(idx) = self.terms.iter().position(|t| t.name() == name) {
78 self.parallel[idx] = true;
79 return true;
80 }
81 false
82 }
83
84 pub fn set_chained(&mut self, name: &str) -> bool {
85 if let Some(idx) = self.terms.iter().position(|t| t.name() == name) {
86 self.parallel[idx] = false;
87 return true;
88 }
89 false
90 }
91
92 pub fn set_all_parallel(&mut self) {
93 self.parallel.iter_mut().for_each(|p| *p = true);
94 }
95
96 pub fn set_all_chained(&mut self) {
97 self.parallel.iter_mut().for_each(|p| *p = false);
98 }
99
100 pub fn is_parallel(&self, idx: usize) -> bool {
101 self.parallel.get(idx).copied().unwrap_or(true)
102 }
103
104 pub fn zero_coefficients(&mut self) {
105 self.coefficients.iter_mut().for_each(|c| *c = 0.0);
106 }
107
108 pub fn term_count(&self) -> usize {
109 self.terms.len()
110 }
111
112 pub fn term_names(&self) -> Vec<&str> {
113 self.terms.iter().map(|t| t.name()).collect()
114 }
115
116 pub fn terms(&self) -> &[Box<dyn Term>] {
117 &self.terms
118 }
119
120 pub fn coefficients(&self) -> &[f64] {
121 &self.coefficients
122 }
123
124 pub fn set_coefficients(&mut self, coeffs: &[f64]) -> Result<()> {
125 if coeffs.len() != self.terms.len() {
126 return Err(Error::Fit(format!(
127 "coefficient count {} does not match term count {}",
128 coeffs.len(),
129 self.terms.len()
130 )));
131 }
132 self.coefficients.copy_from_slice(coeffs);
133 Ok(())
134 }
135
136 pub fn apply_equatorial(&self, h: f64, dec: f64, lat: f64, pier: f64) -> (f64, f64) {
137 let mut dh = 0.0;
138 let mut ddec = 0.0;
139 for (term, &coeff) in self.terms.iter().zip(self.coefficients.iter()) {
140 let (jh, jd) = term.jacobian_equatorial(h, dec, lat, pier);
141 dh += coeff * jh;
142 ddec += coeff * jd;
143 }
144 (dh, ddec)
145 }
146
147 pub fn apply_altaz(&self, az: f64, el: f64, lat: f64) -> (f64, f64) {
148 let mut daz = 0.0;
149 let mut del = 0.0;
150 for (term, &coeff) in self.terms.iter().zip(self.coefficients.iter()) {
151 let (ja, je) = term.jacobian_altaz(az, el, lat);
152 daz += coeff * ja;
153 del += coeff * je;
154 }
155 (daz, del)
156 }
157
158 pub fn apply_equatorial_chained(&self, h: f64, dec: f64, lat: f64, pier: f64) -> (f64, f64) {
159 let mut h_corr = h;
160 let mut dec_corr = dec;
161
162 for (i, term) in self.terms.iter().enumerate() {
163 if !self.parallel[i] {
164 let (jh, jd) = term.jacobian_equatorial(h_corr, dec_corr, lat, pier);
165 h_corr += self.coefficients[i] * jh;
166 dec_corr += self.coefficients[i] * jd;
167 }
168 }
169
170 let mut dh = 0.0;
171 let mut ddec = 0.0;
172 for (i, term) in self.terms.iter().enumerate() {
173 if self.parallel[i] {
174 let (jh, jd) = term.jacobian_equatorial(h_corr, dec_corr, lat, pier);
175 dh += self.coefficients[i] * jh;
176 ddec += self.coefficients[i] * jd;
177 }
178 }
179
180 (h_corr + dh - h, dec_corr + ddec - dec)
181 }
182
183 pub fn target_to_command(
184 &self,
185 ra: Angle,
186 dec: Angle,
187 lst: Angle,
188 lat: Angle,
189 pier: PierSide,
190 ) -> (Angle, Angle) {
191 let ha = lst - ra;
192 let (dh, dd) =
193 self.apply_equatorial(ha.radians(), dec.radians(), lat.radians(), pier.sign());
194 let cmd_ha = ha - Angle::from_arcseconds(dh);
195 let cmd_dec = dec - Angle::from_arcseconds(dd);
196 let cmd_ra = lst - cmd_ha;
197 (cmd_ra, cmd_dec)
198 }
199
200 pub fn command_to_target(
201 &self,
202 ra_encoder: Angle,
203 dec_encoder: Angle,
204 lst: Angle,
205 lat: Angle,
206 pier: PierSide,
207 ) -> (Angle, Angle) {
208 let ha = lst - ra_encoder;
209 let (dh, dd) = self.apply_equatorial(
210 ha.radians(),
211 dec_encoder.radians(),
212 lat.radians(),
213 pier.sign(),
214 );
215 let true_ha = ha + Angle::from_arcseconds(dh);
216 let true_dec = dec_encoder + Angle::from_arcseconds(dd);
217 let true_ra = lst - true_ha;
218 (true_ra, true_dec)
219 }
220
221 pub fn predict_breakdown(
222 &self,
223 h: f64,
224 dec: f64,
225 lat: f64,
226 pier: f64,
227 ) -> Vec<(String, f64, f64)> {
228 self.terms
229 .iter()
230 .zip(self.coefficients.iter())
231 .map(|(term, &coeff)| {
232 let (jh, jd) = term.jacobian_equatorial(h, dec, lat, pier);
233 (term.name().to_string(), coeff * jh, coeff * jd)
234 })
235 .collect()
236 }
237}
238
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::FRAC_PI_4;

    /// Tolerance for results that pass through inexact floating-point steps.
    const EPS: f64 = 1e-9;

    /// Builds a model containing the given terms, in order.
    fn model_with(names: &[&str]) -> PointingModel {
        let mut model = PointingModel::new();
        for &name in names {
            model.add_term(name).unwrap();
        }
        model
    }

    #[test]
    fn apply_equatorial_ih_id() {
        let mut model = model_with(&["IH", "ID"]);
        model.set_coefficients(&[10.0, 20.0]).unwrap();

        let (dh, ddec) = model.apply_equatorial(FRAC_PI_4, 0.5, 0.7, 1.0);
        assert_eq!(dh, -10.0);
        assert_eq!(ddec, -20.0);
    }

    #[test]
    fn apply_equatorial_id_west_pier() {
        let mut model = model_with(&["ID"]);
        model.set_coefficients(&[20.0]).unwrap();

        let (_, ddec) = model.apply_equatorial(0.0, 0.0, 0.0, -1.0);
        assert_eq!(ddec, 20.0);
    }

    #[test]
    fn add_remove_terms() {
        let mut model = model_with(&["IH", "ID", "CH"]);
        assert_eq!(model.term_count(), 3);
        assert_eq!(model.term_names(), vec!["IH", "ID", "CH"]);

        model.remove_term("ID");
        assert_eq!(model.term_count(), 2);
        assert_eq!(model.term_names(), vec!["IH", "CH"]);
        assert_eq!(model.coefficients().len(), 2);
    }

    #[test]
    fn remove_term_keeps_flags_aligned() {
        let mut model = model_with(&["IH", "ID", "CH"]);
        assert!(model.fix_term("CH"));
        assert!(model.set_chained("CH"));

        model.remove_term("ID");
        assert_eq!(model.fixed_flags(), &[false, true]);
        assert!(model.is_parallel(0));
        assert!(!model.is_parallel(1));
    }

    #[test]
    fn remove_all_clears_model() {
        let mut model = model_with(&["IH", "ID"]);
        model.remove_all();
        assert_eq!(model.term_count(), 0);
        assert_eq!(model.coefficients().len(), 0);
        assert!(model.fixed_flags().is_empty());
    }

    #[test]
    fn remove_nonexistent_is_noop() {
        let mut model = model_with(&["IH"]);
        model.remove_term("ZZZZ");
        assert_eq!(model.term_count(), 1);
    }

    #[test]
    fn new_terms_start_zeroed_unfixed_parallel() {
        let model = model_with(&["IH", "ID"]);
        assert_eq!(model.coefficients(), &[0.0, 0.0]);
        assert_eq!(model.fixed_flags(), &[false, false]);
        assert!(model.is_parallel(0));
        assert!(model.is_parallel(1));
    }

    #[test]
    fn fix_and_unfix_terms() {
        let mut model = model_with(&["IH", "ID"]);
        assert!(model.fix_term("ID"));
        assert_eq!(model.fixed_flags(), &[false, true]);
        assert!(!model.fix_term("ZZZZ"));

        model.fix_all();
        assert_eq!(model.fixed_flags(), &[true, true]);

        assert!(model.unfix_term("IH"));
        assert!(!model.unfix_term("ZZZZ"));
        assert_eq!(model.fixed_flags(), &[false, true]);

        model.unfix_all();
        assert_eq!(model.fixed_flags(), &[false, false]);
    }

    #[test]
    fn parallel_and_chained_flags() {
        let mut model = model_with(&["IH", "ID"]);
        assert!(model.set_chained("IH"));
        assert!(!model.is_parallel(0));
        assert!(model.is_parallel(1));
        assert!(!model.set_chained("ZZZZ"));

        model.set_all_chained();
        assert!(!model.is_parallel(0) && !model.is_parallel(1));

        assert!(model.set_parallel("ID"));
        assert!(!model.set_parallel("ZZZZ"));
        assert!(model.is_parallel(1));

        model.set_all_parallel();
        assert!(model.is_parallel(0) && model.is_parallel(1));
    }

    #[test]
    fn out_of_range_flag_defaults() {
        let model = model_with(&["IH"]);
        assert!(!model.is_fixed(5));
        assert!(model.is_parallel(5));
    }

    #[test]
    fn set_coefficients_wrong_length() {
        let mut model = model_with(&["IH"]);
        let result = model.set_coefficients(&[1.0, 2.0]);
        assert!(result.is_err());
    }

    #[test]
    fn set_coefficients_wrong_length_keeps_existing() {
        let mut model = model_with(&["IH"]);
        model.set_coefficients(&[3.0]).unwrap();
        assert!(model.set_coefficients(&[1.0, 2.0]).is_err());
        assert_eq!(model.coefficients(), &[3.0]);
    }

    #[test]
    fn zero_coefficients_resets_values() {
        let mut model = model_with(&["IH", "ID"]);
        model.set_coefficients(&[10.0, 20.0]).unwrap();
        model.zero_coefficients();
        assert_eq!(model.coefficients(), &[0.0, 0.0]);
        assert_eq!(model.term_count(), 2);
    }

    #[test]
    fn chained_matches_parallel_for_index_terms() {
        let mut model = model_with(&["IH", "ID"]);
        model.set_coefficients(&[10.0, 20.0]).unwrap();
        let (dh, ddec) = model.apply_equatorial(FRAC_PI_4, 0.5, 0.7, 1.0);

        // All terms parallel: chained evaluation is the plain sum.
        let (ch, cd) = model.apply_equatorial_chained(FRAC_PI_4, 0.5, 0.7, 1.0);
        assert!((ch - dh).abs() < EPS);
        assert!((cd - ddec).abs() < EPS);

        // IH and ID do not depend on position, so chaining them changes nothing.
        model.set_all_chained();
        let (ch, cd) = model.apply_equatorial_chained(FRAC_PI_4, 0.5, 0.7, 1.0);
        assert!((ch - dh).abs() < EPS);
        assert!((cd - ddec).abs() < EPS);
    }

    #[test]
    fn empty_model_returns_zero_correction() {
        let model = PointingModel::new();
        let (dh, ddec) = model.apply_equatorial(1.0, 0.5, 0.7, 1.0);
        assert_eq!(dh, 0.0);
        assert_eq!(ddec, 0.0);
    }

    #[test]
    fn add_unknown_term_returns_error() {
        let mut model = PointingModel::new();
        let result = model.add_term("ZZZZ");
        assert!(result.is_err());
    }
}