use nalgebra::Matrix3;
use tracing::debug;
use crate::camera_model::CameraModel;
use crate::centroid::Centroid;
use crate::distortion::fit::{fit_polynomial_distortion, DistortionFitConfig};
use crate::solver::wcs_refine;
use crate::solver::{SolveResult, SolveStatus, SolverDatabase};
use super::fit::{
build_id_lookup, compute_corrected_rmse, fit_polynomial_sigma_clip, MatchedPoint,
};
use super::polynomial::{num_coeffs, PolynomialDistortion};
use super::Distortion;
/// Tuning knobs for [`calibrate_camera`].
#[derive(Debug, Clone)]
pub struct CalibrateConfig {
    /// Order of the fitted polynomial distortion; `calibrate_camera` asserts it lies in `2..=6`.
    pub polynomial_order: u32,
    /// Iteration limit forwarded to the sigma-clipping distortion fit
    /// (`DistortionFitConfig::max_iterations`).
    pub max_iterations: u32,
    /// Sigma-clipping threshold forwarded to the distortion fit (`DistortionFitConfig::sigma_clip`).
    pub sigma_clip: f64,
    /// Multi-image path only: the outer refine/fit loop stops once the post-fit RMSE changes
    /// by less than this many pixels between passes.
    pub convergence_threshold_px: f64,
}

impl Default for CalibrateConfig {
    /// Quartic distortion, up to 20 clipping passes at 3 sigma, 0.01 px convergence.
    fn default() -> Self {
        let polynomial_order = 4;
        let max_iterations = 20;
        let sigma_clip = 3.0;
        let convergence_threshold_px = 0.01;
        Self {
            polynomial_order,
            max_iterations,
            sigma_clip,
            convergence_threshold_px,
        }
    }
}
/// Output of `calibrate_camera`: the fitted camera model plus fit statistics.
///
/// NOTE(review): this type derives the rkyv archive traits, so adding, removing or
/// reordering fields likely changes the archived format — check stored archives first.
#[derive(Debug, Clone, rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)]
pub struct CalibrateResult {
/// Fitted camera: focal length, optical-center offset (`crpix`), parity and distortion.
pub camera_model: CameraModel,
/// Residual RMSE in pixels with no distortion correction applied, as reported by the fit.
pub rmse_before_px: f64,
/// Residual RMSE in pixels after applying the fitted distortion.
/// In the multi-image path this stays `f64::MAX` if no polynomial fit was produced.
pub rmse_after_px: f64,
/// Matched points kept by sigma clipping in the final fit.
pub n_inliers: usize,
/// Matched points rejected by sigma clipping in the final fit.
pub n_outliers: usize,
/// Sigma-clipping iterations; in the multi-image path, summed over all outer passes.
pub iterations: u32,
}
/// Calibrates a camera model (focal length, optical-center offset, parity and
/// polynomial distortion) from one or more plate-solved images.
///
/// `solve_results[i]` must describe the same image as `centroids[i]`. Images whose solve
/// status is [`SolveStatus::MatchFound`] and that carry an attitude quaternion count as
/// usable. With at most one usable image, a single-pass distortion fit is performed.
/// Otherwise the iterative multi-image calibration runs: per-image WCS refinement
/// alternates with a joint polynomial fit.
///
/// # Panics
///
/// Panics if `solve_results` and `centroids` differ in length, or if
/// `config.polynomial_order` is outside `2..=6`.
pub fn calibrate_camera(
    solve_results: &[&SolveResult],
    centroids: &[&[Centroid]],
    database: &SolverDatabase,
    image_width: u32,
    image_height: u32,
    config: &CalibrateConfig,
) -> CalibrateResult {
    assert_eq!(
        solve_results.len(),
        centroids.len(),
        "solve_results and centroids must have the same length"
    );
    assert!(
        (2..=6).contains(&config.polynomial_order),
        "polynomial order must be in [2, 6]"
    );
    // Only solves with both a match and an attitude can seed the multi-image refinement.
    let n_valid = solve_results
        .iter()
        .filter(|sr| sr.status == SolveStatus::MatchFound && sr.qicrs2cam.is_some())
        .count();
    if n_valid <= 1 {
        single_image_calibrate(solve_results, centroids, database, image_width, image_height, config)
    } else {
        multi_image_calibrate(solve_results, centroids, database, image_width, image_height, config)
    }
}
/// Moves the constant (zeroth-order) terms of a polynomial distortion into a pixel offset.
///
/// For `Distortion::Polynomial`, the returned offset is `[-ap[0] * scale, -bp[0] * scale]`.
/// The returned model is rebuilt with `PolynomialDistortion::new`, with the constant
/// coefficient of `a`, `b`, `ap` and `bp` set to zero; order and scale are unchanged.
/// Any other variant is passed through untouched with a `[0.0, 0.0]` offset.
fn extract_crpix(distortion: Distortion) -> ([f64; 2], Distortion) {
    let poly = match distortion {
        Distortion::Polynomial(poly) => poly,
        passthrough => return ([0.0, 0.0], passthrough),
    };

    let crpix = [
        -poly.ap_coeffs[0] * poly.scale,
        -poly.bp_coeffs[0] * poly.scale,
    ];

    let (mut a, mut b) = (poly.a_coeffs.clone(), poly.b_coeffs.clone());
    let (mut ap, mut bp) = (poly.ap_coeffs.clone(), poly.bp_coeffs.clone());
    a[0] = 0.0;
    b[0] = 0.0;
    ap[0] = 0.0;
    bp[0] = 0.0;

    let centered = PolynomialDistortion::new(poly.order, poly.scale, a, b, ap, bp);
    (crpix, Distortion::Polynomial(centered))
}
/// Single-pass calibration: fits the polynomial distortion directly against the existing
/// solutions with `fit_polynomial_distortion`, then moves the fitted constant terms into
/// `crpix` via `extract_crpix`.
///
/// The focal length comes from the FOV of the first successfully matched solve that
/// reports a finite, positive FOV. If there is none, it falls back to 0.1 rad. Parity
/// comes from the first successfully matched solve (`false` if none).
fn single_image_calibrate(
    solve_results: &[&SolveResult],
    centroids: &[&[Centroid]],
    database: &SolverDatabase,
    image_width: u32,
    image_height: u32,
    config: &CalibrateConfig,
) -> CalibrateResult {
    // FOV assumed when no matched solve reports a usable one.
    const FALLBACK_FOV_RAD: f32 = 0.1;

    let fit_config = DistortionFitConfig {
        sigma_clip: config.sigma_clip,
        max_iterations: config.max_iterations,
        stage2_threshold_px: Some(5.0),
    };
    let fit_result = fit_polynomial_distortion(
        solve_results,
        centroids,
        database,
        image_width,
        config.polynomial_order,
        &fit_config,
    );

    // Camera metadata is only trusted from solves that actually matched, as in the
    // multi-image path. A zero/NaN FOV would yield an infinite/NaN focal length, so
    // such values are skipped.
    let matched = || {
        solve_results
            .iter()
            .filter(|sr| sr.status == SolveStatus::MatchFound)
    };
    let fov_rad = matched()
        .find_map(|sr| sr.fov_rad.filter(|&fov| fov.is_finite() && fov > 0.0))
        .unwrap_or(FALLBACK_FOV_RAD);
    let parity_flip = matched().next().map_or(false, |sr| sr.parity_flip);

    let (crpix, distortion) = extract_crpix(fit_result.model);
    let cam = CameraModel {
        // Small-angle approximation: focal length (px) = width (px) / FOV (rad).
        focal_length_px: image_width as f64 / fov_rad as f64,
        image_width,
        image_height,
        crpix,
        parity_flip,
        distortion,
    };
    debug!(
        "calibrate_camera (single): order {}, crpix=[{:.2}, {:.2}], RMSE {:.3} -> {:.3} px, {}/{} inliers",
        config.polynomial_order,
        crpix[0], crpix[1],
        fit_result.rmse_before_px,
        fit_result.rmse_after_px,
        fit_result.n_inliers,
        fit_result.n_inliers + fit_result.n_outliers,
    );
    CalibrateResult {
        camera_model: cam,
        rmse_before_px: fit_result.rmse_before_px,
        rmse_after_px: fit_result.rmse_after_px,
        n_inliers: fit_result.n_inliers,
        n_outliers: fit_result.n_outliers,
        iterations: fit_result.iterations,
    }
}
fn multi_image_calibrate(
solve_results: &[&SolveResult],
centroids: &[&[Centroid]],
database: &SolverDatabase,
image_width: u32,
image_height: u32,
config: &CalibrateConfig,
) -> CalibrateResult {
let order = config.polynomial_order;
let scale = image_width as f64 / 2.0;
let id_to_idx = build_id_lookup(database);
let mut fovs: Vec<f32> = Vec::new();
let mut parity_flip = false;
for sr in solve_results.iter() {
if sr.status != SolveStatus::MatchFound {
continue;
}
if let Some(fov) = sr.fov_rad {
fovs.push(fov);
}
parity_flip = sr.parity_flip;
}
fovs.sort_by(|a, b| a.partial_cmp(b).unwrap());
let median_fov = fovs[fovs.len() / 2];
let global_pixel_scale = median_fov as f64 / image_width as f64;
let parity_sign: f64 = if parity_flip { -1.0 } else { 1.0 };
debug!(
"calibrate_camera (multi): {} valid images, median FOV={:.3} deg, parity={}",
fovs.len(),
median_fov.to_degrees(),
parity_flip,
);
let mut current_distortion = Distortion::None;
let mut last_rmse = f64::MAX;
let mut last_rmse_before = 0.0_f64;
let fit_config = DistortionFitConfig {
sigma_clip: config.sigma_clip,
max_iterations: config.max_iterations,
stage2_threshold_px: Some(5.0),
};
struct ImageData {
sr_idx: usize,
rotation: Matrix3<f32>,
fov_rad: f32,
}
let mut image_data: Vec<ImageData> = Vec::new();
for (idx, sr) in solve_results.iter().enumerate() {
if sr.status != SolveStatus::MatchFound {
continue;
}
let quat = match &sr.qicrs2cam {
Some(q) => q,
None => continue,
};
let fov = match sr.fov_rad {
Some(f) => f,
None => continue,
};
let rot: Matrix3<f32> = *quat.to_rotation_matrix().matrix();
image_data.push(ImageData {
sr_idx: idx,
rotation: rot,
fov_rad: fov,
});
}
let mut total_iterations = 0u32;
let mut final_mask = Vec::new();
let mut final_n_points = 0usize;
for outer in 0..3 {
struct RefinedImage {
sr_idx: usize,
matches: Vec<(usize, usize)>, crval_ra: f64,
crval_dec: f64,
cd_matrix: [[f64; 2]; 2],
}
let mut refined_images: Vec<RefinedImage> = Vec::new();
for img in &image_data {
let sr = solve_results[img.sr_idx];
let cents = centroids[img.sr_idx];
let per_image_ps = img.fov_rad as f64 / image_width as f64;
let centroids_px: Vec<(f64, f64)> = cents
.iter()
.map(|c| {
let cx = c.x as f64;
let cy = c.y as f64;
let (ux, uy) = current_distortion.undistort(cx, cy);
(parity_sign * ux, uy)
})
.collect();
let mut initial_matches: Vec<(usize, usize)> = Vec::new();
for (match_idx, &cat_id) in sr.matched_catalog_ids.iter().enumerate() {
let cent_idx = sr.matched_centroid_indices[match_idx];
if cent_idx >= cents.len() {
continue;
}
if let Some(&star_idx) = id_to_idx.get(&cat_id) {
initial_matches.push((cent_idx, star_idx));
}
}
if initial_matches.len() < 4 {
continue;
}
let match_radius_rad = 0.01 * img.fov_rad;
let wcs_result = wcs_refine::wcs_refine(
&img.rotation,
&initial_matches,
¢roids_px,
&database.star_vectors,
&database.star_catalog,
per_image_ps,
parity_flip,
match_radius_rad,
cents.len().min(500),
10,
);
if wcs_result.matches.len() < 4 {
debug!(
" multi-cal outer {}: image {} wcs_refine returned only {} matches, skipping",
outer, img.sr_idx, wcs_result.matches.len()
);
continue;
}
debug!(
" multi-cal outer {}: image {} refined: {} matches, RMSE={:.2}\"",
outer,
img.sr_idx,
wcs_result.matches.len(),
wcs_result.rmse_rad.to_degrees() * 3600.0,
);
refined_images.push(RefinedImage {
sr_idx: img.sr_idx,
matches: wcs_result.matches,
crval_ra: wcs_result.crval_rad[0],
crval_dec: wcs_result.crval_rad[1],
cd_matrix: wcs_result.cd_matrix,
});
}
if refined_images.is_empty() {
debug!(" multi-cal outer {}: no refined images, aborting", outer);
break;
}
let mut all_points: Vec<MatchedPoint> = Vec::new();
for ref_img in &refined_images {
let cents = centroids[ref_img.sr_idx];
let (rot, _fov, _parity) = wcs_refine::wcs_to_rotation(
&ref_img.cd_matrix,
ref_img.crval_ra,
ref_img.crval_dec,
image_width,
);
for &(cent_idx, cat_idx) in &ref_img.matches {
let sv = &database.star_vectors[cat_idx];
let icrs_v = nalgebra::Vector3::new(sv[0], sv[1], sv[2]);
let cam_v = rot * icrs_v;
if cam_v.z <= 0.0 {
continue;
}
let x_ideal = parity_sign * (cam_v.x as f64) / (cam_v.z as f64) / global_pixel_scale;
let y_ideal = (cam_v.y as f64) / (cam_v.z as f64) / global_pixel_scale;
let x_obs = cents[cent_idx].x as f64;
let y_obs = cents[cent_idx].y as f64;
all_points.push(MatchedPoint {
x_obs,
y_obs,
x_ideal,
y_ideal,
});
}
}
if all_points.len() < num_coeffs(order) {
debug!(
" multi-cal outer {}: too few points ({}) for order-{} fit",
outer,
all_points.len(),
order,
);
break;
}
debug!(
" multi-cal outer {}: {} total matched points from {} images",
outer,
all_points.len(),
refined_images.len(),
);
let fit = fit_polynomial_sigma_clip(&all_points, order, scale, &fit_config);
let n_inliers = fit.mask.iter().filter(|&&m| m).count();
let model = PolynomialDistortion::new(
order,
scale,
fit.a_coeffs,
fit.b_coeffs,
fit.ap_coeffs,
fit.bp_coeffs,
);
let dist = Distortion::Polynomial(model);
let rmse_after = compute_corrected_rmse(&all_points, &fit.mask, &dist);
let rmse_before = compute_corrected_rmse(&all_points, &fit.mask, &Distortion::None);
debug!(
" multi-cal outer {}: polynomial fit: {}/{} inliers, RMSE {:.3} -> {:.3} px",
outer, n_inliers, all_points.len(), rmse_before, rmse_after,
);
total_iterations += fit.iterations;
final_mask = fit.mask;
final_n_points = all_points.len();
current_distortion = dist;
last_rmse_before = rmse_before;
let rmse_change = (last_rmse - rmse_after).abs();
let rmse_frac_change = if last_rmse > 1e-12 {
rmse_change / last_rmse
} else {
0.0
};
last_rmse = rmse_after;
if rmse_frac_change < 0.01 || rmse_change < config.convergence_threshold_px {
debug!(
" multi-cal: converged at outer iteration {} (RMSE change={:.4} px, {:.2}%)",
outer, rmse_change, rmse_frac_change * 100.0,
);
break;
}
}
let (crpix, distortion) = extract_crpix(current_distortion);
let cam = CameraModel {
focal_length_px: image_width as f64 / median_fov as f64,
image_width,
image_height,
crpix,
parity_flip,
distortion,
};
let n_inliers = final_mask.iter().filter(|&&m| m).count();
debug!(
"calibrate_camera (multi): order {}, crpix=[{:.2}, {:.2}], RMSE {:.3} -> {:.3} px, {}/{} inliers",
order, crpix[0], crpix[1], last_rmse_before, last_rmse, n_inliers, final_n_points,
);
CalibrateResult {
camera_model: cam,
rmse_before_px: last_rmse_before,
rmse_after_px: last_rmse,
n_inliers,
n_outliers: final_n_points - n_inliers,
iterations: total_iterations,
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Guards the documented defaults of `CalibrateConfig` against silent drift.
    #[test]
    fn test_calibrate_config_defaults() {
        let defaults = CalibrateConfig::default();
        let approx_eq = |actual: f64, expected: f64| (actual - expected).abs() < 1e-12;

        assert_eq!(defaults.polynomial_order, 4);
        assert_eq!(defaults.max_iterations, 20);
        assert!(approx_eq(defaults.sigma_clip, 3.0));
        assert!(approx_eq(defaults.convergence_threshold_px, 0.01));
    }
}