Skip to main content

celestial_pointing/commands/optimal.rs

1use super::{Command, CommandOutput};
2use crate::error::{Error, Result};
3use crate::observation::Observation;
4use crate::session::Session;
5use crate::solver::{fit_model, FitResult};
6use crate::terms::create_term;
7use rayon::prelude::*;
8
/// Terms every searched model starts from; pruning never removes them.
const BASE_TERMS: &[&str] = &["IH", "ID", "CH", "NP", "MA", "ME"];

/// Physical terms offered, one at a time and in this order, in the first growth stage.
const PHYSICAL_CANDIDATES: &[&str] = &[
    "TF", "TX", "DAF", "FO", "HCES", "HCEC", "DCES", "DCEC",
];

/// Model-size limit used when `OPTIMAL` is given no `max_terms` argument.
const DEFAULT_MAX_TERMS: usize = 30;

/// BIC change a candidate must fall strictly below to be accepted when no
/// `bic_threshold` argument is given (a more negative value is stricter).
const DEFAULT_BIC_THRESHOLD: f64 = -6.0;

/// Minimum |coefficient / sigma| a non-base term needs to survive pruning.
const MIN_SIGNIFICANCE: f64 = 2.0;

/// One term accepted during a growth stage, kept so it can be listed in the report.
#[derive(Debug, Clone, PartialEq)]
struct StageEntry {
    /// Name of the accepted term.
    name: String,
    /// BIC of the model with the term minus the BIC of the model before it.
    delta_bic: f64,
    /// Sky RMS of the trial fit that included the term.
    rms: f64,
}

/// The `OPTIMAL` command: automatically builds a model by BIC-driven term selection.
#[derive(Debug, Clone, Copy, Default)]
pub struct Optimal;

23impl Command for Optimal {
24    fn name(&self) -> &str {
25        "OPTIMAL"
26    }
27    fn description(&self) -> &str {
28        "Auto-build optimal model using BIC"
29    }
30
31    fn execute(&self, session: &mut Session, args: &[&str]) -> Result<CommandOutput> {
32        let (max_terms, bic_threshold) = parse_args(args)?;
33        let observations = active_observations(&session.observations);
34        let n_obs = observations.len();
35        if n_obs < BASE_TERMS.len() {
36            return Err(Error::Fit("insufficient observations for OPTIMAL".into()));
37        }
38        let latitude = session.latitude();
39
40        let mut report = String::from("OPTIMAL model search...\n");
41        let mut active: Vec<String> = BASE_TERMS.iter().map(|s| s.to_string()).collect();
42
43        let base_fit = try_fit(&observations, &active, latitude)?;
44        let mut current_bic = compute_bic(n_obs, active.len(), base_fit.sky_rms);
45        append_base_report(&mut report, &active, current_bic, base_fit.sky_rms);
46
47        let mut stage_log = Vec::new();
48        current_bic = run_physical_stage(
49            &observations,
50            &mut active,
51            current_bic,
52            bic_threshold,
53            latitude,
54            &mut stage_log,
55        )?;
56        let _final_stage_bic = run_harmonic_stage(
57            &observations,
58            &mut active,
59            current_bic,
60            bic_threshold,
61            max_terms,
62            latitude,
63            &mut stage_log,
64        )?;
65
66        for entry in &stage_log {
67            report.push_str(&format!(
68                "+ {} (dBIC={:.1}, RMS={:.2}\")\n",
69                entry.name, entry.delta_bic, entry.rms,
70            ));
71        }
72
73        let pruned = prune_terms(&observations, &mut active, latitude)?;
74        for name in &pruned {
75            report.push_str(&format!(
76                "- {} (pruned, significance < {:.1})\n",
77                name, MIN_SIGNIFICANCE
78            ));
79        }
80
81        let final_fit = try_fit(&observations, &active, latitude)?;
82        let _final_bic = compute_bic(n_obs, active.len(), final_fit.sky_rms);
83        append_final_report(&mut report, &active, &final_fit);
84        load_into_session(session, &active, &final_fit)?;
85
86        Ok(CommandOutput::Text(report))
87    }
88}
89
90fn parse_args(args: &[&str]) -> Result<(usize, f64)> {
91    let max_terms = match args.first() {
92        Some(s) => s
93            .parse::<usize>()
94            .map_err(|e| Error::Parse(format!("invalid max_terms: {}", e)))?,
95        None => DEFAULT_MAX_TERMS,
96    };
97    let bic_threshold = match args.get(1) {
98        Some(s) => s
99            .parse::<f64>()
100            .map_err(|e| Error::Parse(format!("invalid bic_threshold: {}", e)))?,
101        None => DEFAULT_BIC_THRESHOLD,
102    };
103    Ok((max_terms, bic_threshold))
104}
105
106fn active_observations(observations: &[Observation]) -> Vec<&Observation> {
107    observations.iter().filter(|o| !o.masked).collect()
108}
109
110fn compute_bic(n_obs: usize, n_terms: usize, sky_rms: f64) -> f64 {
111    let n = n_obs as f64;
112    let k = n_terms as f64;
113    let weighted_rss = sky_rms * sky_rms * n;
114    n * libm::log(weighted_rss / n) + k * libm::log(n)
115}
116
117fn try_fit(
118    observations: &[&Observation],
119    term_names: &[String],
120    latitude: f64,
121) -> Result<FitResult> {
122    let terms: Vec<_> = term_names
123        .iter()
124        .map(|n| create_term(n))
125        .collect::<Result<Vec<_>>>()?;
126    let fixed = vec![false; terms.len()];
127    let coeffs = vec![0.0; terms.len()];
128    fit_model(observations, &terms, &fixed, &coeffs, latitude)
129}
130
131fn run_physical_stage(
132    observations: &[&Observation],
133    active: &mut Vec<String>,
134    mut current_bic: f64,
135    threshold: f64,
136    latitude: f64,
137    log: &mut Vec<StageEntry>,
138) -> Result<f64> {
139    let n_obs = observations.len();
140    for &candidate in PHYSICAL_CANDIDATES {
141        if active.len() >= DEFAULT_MAX_TERMS {
142            break;
143        }
144        let mut trial = active.clone();
145        trial.push(candidate.to_string());
146        let fit = match try_fit(observations, &trial, latitude) {
147            Ok(f) => f,
148            Err(_) => continue,
149        };
150        let trial_bic = compute_bic(n_obs, trial.len(), fit.sky_rms);
151        let delta = trial_bic - current_bic;
152        if delta < threshold {
153            log.push(StageEntry {
154                name: candidate.to_string(),
155                delta_bic: delta,
156                rms: fit.sky_rms,
157            });
158            active.push(candidate.to_string());
159            current_bic = trial_bic;
160        }
161    }
162    Ok(current_bic)
163}
164
/// Enumerates every harmonic term name the harmonic stage may try.
///
/// Names are `H{r}{f}{c}{n}` with `r` in {H, D, X}, `f` in {S, C}, `c` in {H, D}
/// and `n` in 1..=8, where `n` is omitted when it is 1; 96 names in total, in
/// that nesting order (r outermost, n innermost).
fn generate_harmonic_candidates() -> Vec<String> {
    static RESULTS: [&str; 3] = ["H", "D", "X"];
    static FUNCS: [&str; 2] = ["S", "C"];
    static COORDS: [&str; 2] = ["H", "D"];
    const MAX_ORDER: u8 = 8;

    RESULTS
        .iter()
        .flat_map(|result| {
            FUNCS.iter().flat_map(move |func| {
                COORDS.iter().flat_map(move |coord| {
                    (1..=MAX_ORDER).map(move |order| {
                        let suffix = match order {
                            1 => String::new(),
                            _ => order.to_string(),
                        };
                        format!("H{}{}{}{}", result, func, coord, suffix)
                    })
                })
            })
        })
        .collect()
}

183fn run_harmonic_stage(
184    observations: &[&Observation],
185    active: &mut Vec<String>,
186    mut current_bic: f64,
187    threshold: f64,
188    max_terms: usize,
189    latitude: f64,
190    log: &mut Vec<StageEntry>,
191) -> Result<f64> {
192    let all_candidates = generate_harmonic_candidates();
193    let n_obs = observations.len();
194
195    loop {
196        if active.len() >= max_terms {
197            break;
198        }
199        let candidates: Vec<&String> = all_candidates
200            .iter()
201            .filter(|c| !active.contains(c))
202            .collect();
203        if candidates.is_empty() {
204            break;
205        }
206
207        let best = find_best_harmonic(observations, active, &candidates, n_obs, latitude);
208        match best {
209            Some((name, bic, rms)) => {
210                let delta = bic - current_bic;
211                if delta < threshold {
212                    log.push(StageEntry {
213                        name: name.clone(),
214                        delta_bic: delta,
215                        rms,
216                    });
217                    active.push(name);
218                    current_bic = bic;
219                } else {
220                    break;
221                }
222            }
223            None => break,
224        }
225    }
226    Ok(current_bic)
227}
228
229fn find_best_harmonic(
230    observations: &[&Observation],
231    active: &[String],
232    candidates: &[&String],
233    n_obs: usize,
234    latitude: f64,
235) -> Option<(String, f64, f64)> {
236    candidates
237        .par_iter()
238        .filter_map(|candidate| {
239            let mut trial = active.to_vec();
240            trial.push((*candidate).clone());
241            let fit = try_fit(observations, &trial, latitude).ok()?;
242            let bic = compute_bic(n_obs, trial.len(), fit.sky_rms);
243            Some(((*candidate).clone(), bic, fit.sky_rms))
244        })
245        .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
246}
247
248fn prune_terms(
249    observations: &[&Observation],
250    active: &mut Vec<String>,
251    latitude: f64,
252) -> Result<Vec<String>> {
253    let fit = try_fit(observations, active, latitude)?;
254    let mut pruned = Vec::new();
255    let base_set: Vec<String> = BASE_TERMS.iter().map(|s| s.to_string()).collect();
256
257    let to_remove: Vec<String> = active
258        .iter()
259        .enumerate()
260        .filter(|(i, name)| {
261            if base_set.contains(name) {
262                return false;
263            }
264            let sigma = fit.sigma[*i];
265            if sigma == 0.0 {
266                return false;
267            }
268            (fit.coefficients[*i] / sigma).abs() < MIN_SIGNIFICANCE
269        })
270        .map(|(_, name)| name.clone())
271        .collect();
272
273    for name in &to_remove {
274        active.retain(|n| n != name);
275        pruned.push(name.clone());
276    }
277    Ok(pruned)
278}
279
280fn load_into_session(session: &mut Session, active: &[String], result: &FitResult) -> Result<()> {
281    session.model.remove_all();
282    for name in active {
283        session.model.add_term(name)?;
284    }
285    session.model.set_coefficients(&result.coefficients)?;
286    session.last_fit = Some(result.clone());
287    Ok(())
288}
289
/// Appends the `Base:` line: the space-separated starting terms, their BIC
/// (one decimal) and sky RMS (two decimals).
fn append_base_report(report: &mut String, terms: &[String], bic: f64, rms: f64) {
    let line = format!(
        "Base: {names} (BIC={bic:.1}, RMS={rms:.2}\")\n",
        names = terms.join(" "),
        bic = bic,
        rms = rms,
    );
    *report += &line;
}

299fn append_final_report(report: &mut String, terms: &[String], fit: &FitResult) {
300    report.push_str(&format!(
301        "\nFinal model: {} terms, RMS={:.2}\"\n",
302        terms.len(),
303        fit.sky_rms,
304    ));
305    report.push_str("Terms: ");
306    report.push_str(&terms.join(" "));
307    report.push('\n');
308}