celestial_pointing/commands/
optimal.rs1use super::{Command, CommandOutput};
2use crate::error::{Error, Result};
3use crate::observation::Observation;
4use crate::session::Session;
5use crate::solver::{fit_model, FitResult};
6use crate::terms::create_term;
7use rayon::prelude::*;
8
/// Terms every OPTIMAL search starts from; `prune_terms` never removes them.
const BASE_TERMS: &[&str] = &["IH", "ID", "CH", "NP", "MA", "ME"];

/// Named terms tried one at a time, in this order, by `run_physical_stage`.
const PHYSICAL_CANDIDATES: &[&str] = &[
    "TF", "TX", "DAF", "FO", "HCES", "HCEC", "DCES", "DCEC",
];

/// Term cap used when the command is invoked without a first argument.
const DEFAULT_MAX_TERMS: usize = 30;

/// BIC-change threshold used when the command is invoked without a second
/// argument. A candidate is accepted only when `trial_bic - current_bic` is
/// strictly below this value.
const DEFAULT_BIC_THRESHOLD: f64 = -6.0;

/// Non-base terms whose `|coefficient / sigma|` falls below this are pruned.
const MIN_SIGNIFICANCE: f64 = 2.0;
14
/// One accepted term, recorded so the command report can list it afterwards.
struct StageEntry {
    /// Name of the term that was added to the model.
    name: String,
    /// BIC of the trial model minus the BIC of the model before the addition.
    delta_bic: f64,
    /// Sky RMS of the trial fit that included the term (the report prints it
    /// with a `"` suffix, presumably arcseconds).
    rms: f64,
}
20
21pub struct Optimal;
22
23impl Command for Optimal {
24 fn name(&self) -> &str {
25 "OPTIMAL"
26 }
27 fn description(&self) -> &str {
28 "Auto-build optimal model using BIC"
29 }
30
31 fn execute(&self, session: &mut Session, args: &[&str]) -> Result<CommandOutput> {
32 let (max_terms, bic_threshold) = parse_args(args)?;
33 let observations = active_observations(&session.observations);
34 let n_obs = observations.len();
35 if n_obs < BASE_TERMS.len() {
36 return Err(Error::Fit("insufficient observations for OPTIMAL".into()));
37 }
38 let latitude = session.latitude();
39
40 let mut report = String::from("OPTIMAL model search...\n");
41 let mut active: Vec<String> = BASE_TERMS.iter().map(|s| s.to_string()).collect();
42
43 let base_fit = try_fit(&observations, &active, latitude)?;
44 let mut current_bic = compute_bic(n_obs, active.len(), base_fit.sky_rms);
45 append_base_report(&mut report, &active, current_bic, base_fit.sky_rms);
46
47 let mut stage_log = Vec::new();
48 current_bic = run_physical_stage(
49 &observations,
50 &mut active,
51 current_bic,
52 bic_threshold,
53 latitude,
54 &mut stage_log,
55 )?;
56 let _final_stage_bic = run_harmonic_stage(
57 &observations,
58 &mut active,
59 current_bic,
60 bic_threshold,
61 max_terms,
62 latitude,
63 &mut stage_log,
64 )?;
65
66 for entry in &stage_log {
67 report.push_str(&format!(
68 "+ {} (dBIC={:.1}, RMS={:.2}\")\n",
69 entry.name, entry.delta_bic, entry.rms,
70 ));
71 }
72
73 let pruned = prune_terms(&observations, &mut active, latitude)?;
74 for name in &pruned {
75 report.push_str(&format!(
76 "- {} (pruned, significance < {:.1})\n",
77 name, MIN_SIGNIFICANCE
78 ));
79 }
80
81 let final_fit = try_fit(&observations, &active, latitude)?;
82 let _final_bic = compute_bic(n_obs, active.len(), final_fit.sky_rms);
83 append_final_report(&mut report, &active, &final_fit);
84 load_into_session(session, &active, &final_fit)?;
85
86 Ok(CommandOutput::Text(report))
87 }
88}
89
90fn parse_args(args: &[&str]) -> Result<(usize, f64)> {
91 let max_terms = match args.first() {
92 Some(s) => s
93 .parse::<usize>()
94 .map_err(|e| Error::Parse(format!("invalid max_terms: {}", e)))?,
95 None => DEFAULT_MAX_TERMS,
96 };
97 let bic_threshold = match args.get(1) {
98 Some(s) => s
99 .parse::<f64>()
100 .map_err(|e| Error::Parse(format!("invalid bic_threshold: {}", e)))?,
101 None => DEFAULT_BIC_THRESHOLD,
102 };
103 Ok((max_terms, bic_threshold))
104}
105
106fn active_observations(observations: &[Observation]) -> Vec<&Observation> {
107 observations.iter().filter(|o| !o.masked).collect()
108}
109
110fn compute_bic(n_obs: usize, n_terms: usize, sky_rms: f64) -> f64 {
111 let n = n_obs as f64;
112 let k = n_terms as f64;
113 let weighted_rss = sky_rms * sky_rms * n;
114 n * libm::log(weighted_rss / n) + k * libm::log(n)
115}
116
117fn try_fit(
118 observations: &[&Observation],
119 term_names: &[String],
120 latitude: f64,
121) -> Result<FitResult> {
122 let terms: Vec<_> = term_names
123 .iter()
124 .map(|n| create_term(n))
125 .collect::<Result<Vec<_>>>()?;
126 let fixed = vec![false; terms.len()];
127 let coeffs = vec![0.0; terms.len()];
128 fit_model(observations, &terms, &fixed, &coeffs, latitude)
129}
130
131fn run_physical_stage(
132 observations: &[&Observation],
133 active: &mut Vec<String>,
134 mut current_bic: f64,
135 threshold: f64,
136 latitude: f64,
137 log: &mut Vec<StageEntry>,
138) -> Result<f64> {
139 let n_obs = observations.len();
140 for &candidate in PHYSICAL_CANDIDATES {
141 if active.len() >= DEFAULT_MAX_TERMS {
142 break;
143 }
144 let mut trial = active.clone();
145 trial.push(candidate.to_string());
146 let fit = match try_fit(observations, &trial, latitude) {
147 Ok(f) => f,
148 Err(_) => continue,
149 };
150 let trial_bic = compute_bic(n_obs, trial.len(), fit.sky_rms);
151 let delta = trial_bic - current_bic;
152 if delta < threshold {
153 log.push(StageEntry {
154 name: candidate.to_string(),
155 delta_bic: delta,
156 rms: fit.sky_rms,
157 });
158 active.push(candidate.to_string());
159 current_bic = trial_bic;
160 }
161 }
162 Ok(current_bic)
163}
164
/// Builds the 96 harmonic term names `H<result><func><coord>[n]`.
///
/// `result` is one of `H`, `D`, `X`; `func` is `S` or `C`; `coord` is `H` or
/// `D`; `n` runs from 1 to 8 and is omitted from the name when it is 1
/// (so `HHSH`, `HHSH2`, ..., `HHSH8`). Names are produced with `result`
/// varying slowest and `n` varying fastest.
fn generate_harmonic_candidates() -> Vec<String> {
    const RESULTS: &[&str] = &["H", "D", "X"];
    const FUNCS: &[&str] = &["S", "C"];
    const COORDS: &[&str] = &["H", "D"];

    RESULTS
        .iter()
        .flat_map(|r| FUNCS.iter().map(move |f| (r, f)))
        .flat_map(|(r, f)| COORDS.iter().map(move |c| (r, f, c)))
        .flat_map(|(r, f, c)| {
            (1..=8u8).map(move |n| match n {
                // The fundamental carries no numeric suffix.
                1 => format!("H{}{}{}", r, f, c),
                _ => format!("H{}{}{}{}", r, f, c, n),
            })
        })
        .collect()
}
182
183fn run_harmonic_stage(
184 observations: &[&Observation],
185 active: &mut Vec<String>,
186 mut current_bic: f64,
187 threshold: f64,
188 max_terms: usize,
189 latitude: f64,
190 log: &mut Vec<StageEntry>,
191) -> Result<f64> {
192 let all_candidates = generate_harmonic_candidates();
193 let n_obs = observations.len();
194
195 loop {
196 if active.len() >= max_terms {
197 break;
198 }
199 let candidates: Vec<&String> = all_candidates
200 .iter()
201 .filter(|c| !active.contains(c))
202 .collect();
203 if candidates.is_empty() {
204 break;
205 }
206
207 let best = find_best_harmonic(observations, active, &candidates, n_obs, latitude);
208 match best {
209 Some((name, bic, rms)) => {
210 let delta = bic - current_bic;
211 if delta < threshold {
212 log.push(StageEntry {
213 name: name.clone(),
214 delta_bic: delta,
215 rms,
216 });
217 active.push(name);
218 current_bic = bic;
219 } else {
220 break;
221 }
222 }
223 None => break,
224 }
225 }
226 Ok(current_bic)
227}
228
229fn find_best_harmonic(
230 observations: &[&Observation],
231 active: &[String],
232 candidates: &[&String],
233 n_obs: usize,
234 latitude: f64,
235) -> Option<(String, f64, f64)> {
236 candidates
237 .par_iter()
238 .filter_map(|candidate| {
239 let mut trial = active.to_vec();
240 trial.push((*candidate).clone());
241 let fit = try_fit(observations, &trial, latitude).ok()?;
242 let bic = compute_bic(n_obs, trial.len(), fit.sky_rms);
243 Some(((*candidate).clone(), bic, fit.sky_rms))
244 })
245 .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
246}
247
248fn prune_terms(
249 observations: &[&Observation],
250 active: &mut Vec<String>,
251 latitude: f64,
252) -> Result<Vec<String>> {
253 let fit = try_fit(observations, active, latitude)?;
254 let mut pruned = Vec::new();
255 let base_set: Vec<String> = BASE_TERMS.iter().map(|s| s.to_string()).collect();
256
257 let to_remove: Vec<String> = active
258 .iter()
259 .enumerate()
260 .filter(|(i, name)| {
261 if base_set.contains(name) {
262 return false;
263 }
264 let sigma = fit.sigma[*i];
265 if sigma == 0.0 {
266 return false;
267 }
268 (fit.coefficients[*i] / sigma).abs() < MIN_SIGNIFICANCE
269 })
270 .map(|(_, name)| name.clone())
271 .collect();
272
273 for name in &to_remove {
274 active.retain(|n| n != name);
275 pruned.push(name.clone());
276 }
277 Ok(pruned)
278}
279
280fn load_into_session(session: &mut Session, active: &[String], result: &FitResult) -> Result<()> {
281 session.model.remove_all();
282 for name in active {
283 session.model.add_term(name)?;
284 }
285 session.model.set_coefficients(&result.coefficients)?;
286 session.last_fit = Some(result.clone());
287 Ok(())
288}
289
/// Appends the one-line summary of the base model to `report`.
///
/// The line lists the space-separated term names, the BIC with one decimal,
/// and the RMS with two decimals followed by a `"` mark.
fn append_base_report(report: &mut String, terms: &[String], bic: f64, rms: f64) {
    let term_list = terms.join(" ");
    let line = format!(
        "Base: {} (BIC={:.1}, RMS={:.2}\")\n",
        term_list, bic, rms,
    );
    report.push_str(line.as_str());
}
298
299fn append_final_report(report: &mut String, terms: &[String], fit: &FitResult) {
300 report.push_str(&format!(
301 "\nFinal model: {} terms, RMS={:.2}\"\n",
302 terms.len(),
303 fit.sky_rms,
304 ));
305 report.push_str("Terms: ");
306 report.push_str(&terms.join(" "));
307 report.push('\n');
308}