use std::borrow::Cow;
use std::time::Instant;
use nalgebra::{Matrix3, Rotation3, UnitQuaternion, Vector3};
use tracing::debug;
use crate::Centroid;
use super::combinations::BreadthFirstCombinations;
use super::pattern::{
compute_edge_ratios, compute_pattern_key, compute_pattern_key_hash, compute_sorted_edge_angles,
hash_to_index, NUM_EDGES, NUM_EDGE_RATIOS, PATTERN_SIZE,
};
use super::wcs_refine;
use super::{SolveConfig, SolveResult, SolveStatus, SolverDatabase};
/// Speed of light in vacuum, in km/s. Observer velocities given in km/s are
/// divided by this to obtain beta = v/c for `aberration_correct`.
const C_KM_S: f64 = 299_792.458;
/// Displace a catalog unit vector by first-order stellar aberration.
///
/// `beta` is the observer velocity divided by the speed of light, expressed in the
/// same frame as `sv`. The apparent direction `u + beta - u * (u . beta)` is
/// evaluated in `f64`, renormalised to unit length, and narrowed back to `f32`.
fn aberration_correct(sv: &[f32; 3], beta: &[f64; 3]) -> [f32; 3] {
    let u = sv.map(f64::from);
    let u_dot_beta = u[0] * beta[0] + u[1] * beta[1] + u[2] * beta[2];
    let apparent: [f64; 3] = std::array::from_fn(|k| u[k] + beta[k] - u[k] * u_dot_beta);
    let length = (apparent[0] * apparent[0]
        + apparent[1] * apparent[1]
        + apparent[2] * apparent[2])
        .sqrt();
    apparent.map(|component| (component / length) as f32)
}
impl SolverDatabase {
pub fn solve_from_centroids(
&self,
centroids: &[Centroid],
config: &SolveConfig,
) -> SolveResult {
let t0 = Instant::now();
let star_vecs: Cow<[[f32; 3]]> = match config.observer_velocity_km_s {
Some(v) => {
let beta = [v[0] / C_KM_S, v[1] / C_KM_S, v[2] / C_KM_S];
Cow::Owned(
self.star_vectors
.iter()
.map(|sv| aberration_correct(sv, &beta))
.collect(),
)
}
None => Cow::Borrowed(&self.star_vectors),
};
let cam = &config.camera_model;
let preprocessed: Vec<Centroid> = centroids
.iter()
.map(|c| {
let cx = c.x as f64 - cam.crpix[0];
let cy = c.y as f64 - cam.crpix[1];
let (ux, uy) = cam.distortion.undistort(cx, cy);
Centroid {
x: ux as f32,
y: uy as f32,
mass: c.mass,
cov: c.cov,
}
})
.collect();
let working_centroids: &[Centroid] = &preprocessed;
let fov_values = build_fov_sweep(
config.fov_estimate_rad,
config.fov_max_error_rad,
config.match_radius,
);
debug!(
"FOV sweep: {} values from {:.2}° to {:.2}°",
fov_values.len(),
fov_values
.iter()
.cloned()
.reduce(f32::min)
.unwrap_or(0.0)
.to_degrees(),
fov_values
.iter()
.cloned()
.reduce(f32::max)
.unwrap_or(0.0)
.to_degrees(),
);
let mut last_status = SolveStatus::NoMatch;
for &fov_try in &fov_values {
if let Some(t) = config.solve_timeout_ms {
if elapsed_ms(t0) > t as f32 {
return SolveResult::failure(SolveStatus::Timeout, elapsed_ms(t0));
}
}
debug!("Trying FOV = {:.3}°", fov_try.to_degrees());
let result = self.solve_at_fov(working_centroids, config, fov_try, &star_vecs, t0);
match result.status {
SolveStatus::MatchFound => return result,
SolveStatus::TooFew => return result,
s => last_status = s,
}
}
SolveResult::failure(last_status, elapsed_ms(t0))
}
fn solve_at_fov(
&self,
centroids: &[Centroid],
config: &SolveConfig,
fov_estimate: f32,
star_vectors: &[[f32; 3]],
t0: Instant,
) -> SolveResult {
let pixel_scale = if config.image_width > 0 {
fov_estimate / config.image_width as f32
} else {
0.0
};
let mut sorted_indices: Vec<usize> = (0..centroids.len()).collect();
sorted_indices.sort_by(|&a, &b| {
let ma = centroids[a].mass.unwrap_or(f32::MIN);
let mb = centroids[b].mass.unwrap_or(f32::MIN);
mb.partial_cmp(&ma).unwrap_or(std::cmp::Ordering::Equal)
});
let num_centroids = sorted_indices.len();
if num_centroids < PATTERN_SIZE {
return SolveResult::failure(SolveStatus::TooFew, elapsed_ms(t0));
}
let centroid_vectors: Vec<[f32; 3]> = sorted_indices
.iter()
.map(|&i| {
let x = centroids[i].x * pixel_scale;
let y = centroids[i].y * pixel_scale;
let z = 1.0f32;
let norm = (x * x + y * y + z * z).sqrt();
[x / norm, y / norm, z / norm]
})
.collect();
let mut flipped_vectors: Option<Vec<[f32; 3]>> = None;
let verification_stars = self.props.verification_stars_per_fov;
let separation = separation_for_density(fov_estimate, verification_stars);
let cos_sep = separation.cos();
let mut keep_for_patterns = vec![false; num_centroids];
for i in 0..num_centroids {
let vi = ¢roid_vectors[i];
let mut occupied = false;
for j in 0..i {
if keep_for_patterns[j] {
let vj = ¢roid_vectors[j];
let dot = vi[0] * vj[0] + vi[1] * vj[1] + vi[2] * vj[2];
if dot > cos_sep {
occupied = true;
break;
}
}
}
if !occupied {
keep_for_patterns[i] = true;
}
}
let pattern_centroid_inds: Vec<usize> = (0..num_centroids)
.filter(|&i| keep_for_patterns[i])
.collect();
let num_pattern_centroids = pattern_centroid_inds.len();
debug!(
"Centroids: {} total, {} for patterns after cluster busting",
num_centroids, num_pattern_centroids
);
if num_pattern_centroids < PATTERN_SIZE {
return SolveResult::failure(SolveStatus::TooFew, elapsed_ms(t0));
}
let match_centroid_count = num_centroids.min(verification_stars as usize);
let p_bins = self.props.pattern_bins;
let p_max_err = config
.match_max_error
.unwrap_or(self.props.pattern_max_error)
.max(self.props.pattern_max_error);
let match_threshold = config.match_threshold / self.props.num_patterns as f64;
let timeout_ms = config.solve_timeout_ms;
debug!(
"Checking up to C({},{}) = {} image patterns",
num_pattern_centroids,
PATTERN_SIZE,
n_choose_k(num_pattern_centroids, PATTERN_SIZE)
);
let mut status = SolveStatus::NoMatch;
let mut pattern_key_list: Vec<(u32, [u32; NUM_EDGE_RATIOS])> = Vec::new();
for image_pattern_local in
BreadthFirstCombinations::<PATTERN_SIZE>::new(&pattern_centroid_inds)
{
if let Some(t) = timeout_ms {
if elapsed_ms(t0) > t as f32 {
debug!("Timeout after {:.1}ms", elapsed_ms(t0));
status = SolveStatus::Timeout;
break;
}
}
let image_vecs: [[f32; 3]; 4] = [
centroid_vectors[image_pattern_local[0]],
centroid_vectors[image_pattern_local[1]],
centroid_vectors[image_pattern_local[2]],
centroid_vectors[image_pattern_local[3]],
];
let edge_angles = compute_sorted_edge_angles(&image_vecs);
let image_largest_edge = edge_angles[NUM_EDGES - 1];
let image_ratios = compute_edge_ratios(&edge_angles);
let ratio_min: [f32; NUM_EDGE_RATIOS] =
std::array::from_fn(|i| image_ratios[i] - p_max_err);
let ratio_max: [f32; NUM_EDGE_RATIOS] =
std::array::from_fn(|i| image_ratios[i] + p_max_err);
let image_key = compute_pattern_key(&image_ratios, p_bins);
let key_min: [u32; NUM_EDGE_RATIOS] =
std::array::from_fn(|i| (ratio_min[i] * p_bins as f32).max(0.0) as u32);
let key_max: [u32; NUM_EDGE_RATIOS] =
std::array::from_fn(|i| (ratio_max[i] * p_bins as f32).min(p_bins as f32) as u32);
pattern_key_list.clear();
enumerate_key_range(&key_min, &key_max, &image_key, &mut pattern_key_list);
pattern_key_list.sort_unstable_by_key(|&(dist, _)| dist);
let table_len = self.pattern_catalog.len() as u64;
for &(_, ref pkey) in &pattern_key_list {
let pkey_hash = compute_pattern_key_hash(pkey, p_bins);
let hidx = hash_to_index(pkey_hash, table_len);
let key_hash16 = (pkey_hash & 0xFFFF) as u16;
for c in 0u64.. {
let tidx = ((hidx.wrapping_add(c.wrapping_mul(c))) % table_len) as usize;
let entry = &self.pattern_catalog[tidx];
if entry.is_empty() {
break; }
if entry.key_hash != key_hash16 {
continue;
}
let cat_largest = entry.largest_edge;
if let Some(fov_err) = config.fov_max_error_rad {
let implied_fov = cat_largest / image_largest_edge * fov_estimate;
if (implied_fov - fov_estimate).abs() > fov_err {
continue;
}
}
let cat_pat = entry.star_indices;
let cat_vecs: [[f32; 3]; 4] = [
star_vectors[cat_pat[0] as usize],
star_vectors[cat_pat[1] as usize],
star_vectors[cat_pat[2] as usize],
star_vectors[cat_pat[3] as usize],
];
let cat_edges = compute_sorted_edge_angles(&cat_vecs);
let cat_largest_edge = cat_edges[NUM_EDGES - 1];
let cat_ratios = compute_edge_ratios(&cat_edges);
let ratios_ok = (0..NUM_EDGE_RATIOS)
.all(|i| cat_ratios[i] > ratio_min[i] && cat_ratios[i] < ratio_max[i]);
if !ratios_ok {
continue;
}
let fov = cat_largest_edge / image_largest_edge * fov_estimate;
let mut img_order: [usize; 4] = [0, 1, 2, 3];
sort_by_centroid_distance_inline(&mut img_order, &image_vecs);
let matched_img: [[f32; 3]; 4] =
std::array::from_fn(|i| image_vecs[img_order[i]]);
let matched_cat: [[f32; 3]; 4] = std::array::from_fn(|i| cat_vecs[i]);
let mut rotation_matrix = find_rotation_matrix(&matched_img, &matched_cat);
let parity_flip;
let working_vectors: &[[f32; 3]];
if rotation_matrix.determinant() < 0.0 {
parity_flip = true;
let matched_img_flip: [[f32; 3]; 4] = std::array::from_fn(|i| {
let orig = image_vecs[img_order[i]];
[-orig[0], orig[1], orig[2]]
});
rotation_matrix = find_rotation_matrix(&matched_img_flip, &matched_cat);
if rotation_matrix.determinant() < 0.0 {
continue; }
let fv = flipped_vectors.get_or_insert_with(|| {
centroid_vectors
.iter()
.map(|v| [-v[0], v[1], v[2]])
.collect()
});
working_vectors = fv;
} else {
parity_flip = false;
working_vectors = ¢roid_vectors;
}
let fov_diagonal = fov * 1.42; let match_radius_rad = config.match_radius * fov;
let image_center_icrs =
rotation_matrix.transpose() * Vector3::new(0.0, 0.0, 1.0);
let nearby_inds = self
.star_catalog
.query_indices_from_uvec(image_center_icrs, fov_diagonal / 2.0);
let mut nearby_cam_positions: Vec<(usize, f32, f32)> = Vec::new();
for &cat_idx in &nearby_inds {
let sv = &star_vectors[cat_idx];
let icrs_v = Vector3::new(sv[0], sv[1], sv[2]);
let cam_v = rotation_matrix * icrs_v;
if cam_v.z > 0.0 {
let cx = cam_v.x / cam_v.z; let cy = cam_v.y / cam_v.z;
nearby_cam_positions.push((cat_idx, cx, cy));
}
}
nearby_cam_positions.truncate(2 * match_centroid_count);
let num_nearby = nearby_cam_positions.len();
let current_matches = find_centroid_matches(
&working_vectors[..match_centroid_count],
&nearby_cam_positions,
match_radius_rad,
);
let current_num_matches = current_matches.len();
let prob_single = num_nearby as f64 * (config.match_radius as f64).powi(2);
let prob_mismatch = binomial_cdf(
(match_centroid_count as i64 - (current_num_matches as i64 - 2)).max(0)
as u32,
match_centroid_count as u32,
1.0 - prob_single.min(1.0),
);
if prob_mismatch >= match_threshold {
continue;
}
debug!(
"MATCH: {} matches, prob={:.2e}, fov={:.3}°",
current_num_matches,
prob_mismatch * self.props.num_patterns as f64,
fov.to_degrees()
);
let parity_sign: f64 = if parity_flip { -1.0 } else { 1.0 };
let centroids_px: Vec<(f64, f64)> = sorted_indices
.iter()
.map(|&i| {
let px = parity_sign * centroids[i].x as f64;
let py = centroids[i].y as f64;
(px, py)
})
.collect();
let ps_refine = fov as f64 / config.image_width as f64;
let wcs_result = wcs_refine::wcs_refine(
&rotation_matrix,
¤t_matches,
¢roids_px,
star_vectors,
&self.star_catalog,
ps_refine,
parity_flip,
match_radius_rad,
match_centroid_count,
10,
);
if wcs_result.matches.len() < 4 {
continue;
}
let (refined_rotation, refined_fov, _) =
wcs_refine::wcs_to_rotation(
&wcs_result.cd_matrix,
wcs_result.crval_rad[0],
wcs_result.crval_rad[1],
config.image_width,
);
let ps = refined_fov / config.image_width.max(1) as f32;
let mut matched_cat_ids: Vec<u64> =
Vec::with_capacity(wcs_result.matches.len());
let mut matched_cent_inds: Vec<usize> =
Vec::with_capacity(wcs_result.matches.len());
let mut angular_residuals: Vec<f32> =
Vec::with_capacity(wcs_result.matches.len());
for &(cent_local_idx, cat_star_idx) in &wcs_result.matches {
matched_cat_ids.push(self.star_catalog_ids[cat_star_idx]);
matched_cent_inds.push(sorted_indices[cent_local_idx]);
let (px, py) = centroids_px[cent_local_idx];
let ix = px as f32 * ps;
let iy = py as f32 * ps;
let iz = 1.0f32;
let norm = (ix * ix + iy * iy + iz * iz).sqrt();
let img_v = refined_rotation.transpose()
* Vector3::new(ix / norm, iy / norm, iz / norm);
let sv = &star_vectors[cat_star_idx];
let cat_v = Vector3::new(sv[0], sv[1], sv[2]);
let cross = img_v.cross(&cat_v);
let ang = cross.norm().atan2(img_v.dot(&cat_v));
angular_residuals.push(ang);
}
angular_residuals
.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
let rmse = if angular_residuals.is_empty() {
0.0
} else {
(angular_residuals.iter().map(|r| r * r).sum::<f32>()
/ angular_residuals.len() as f32)
.sqrt()
};
let p90e = if angular_residuals.is_empty() {
0.0
} else {
angular_residuals
[(0.9 * (angular_residuals.len() - 1) as f32) as usize]
};
let max_err = angular_residuals.last().copied().unwrap_or(0.0);
let rot3 = Rotation3::from_matrix_unchecked(refined_rotation);
let quat = UnitQuaternion::from_rotation_matrix(&rot3);
let mut result_cam = config.camera_model.clone();
let refined_f = (config.image_width as f64 / 2.0)
/ (refined_fov as f64 / 2.0).tan();
result_cam.focal_length_px = refined_f;
result_cam.parity_flip = parity_flip;
return SolveResult {
qicrs2cam: Some(quat),
fov_rad: Some(refined_fov),
num_matches: Some(wcs_result.matches.len() as u32),
rmse_rad: Some(rmse),
p90e_rad: Some(p90e),
max_err_rad: Some(max_err),
prob: Some(prob_mismatch * self.props.num_patterns as f64),
solve_time_ms: elapsed_ms(t0),
status: SolveStatus::MatchFound,
parity_flip,
matched_catalog_ids: matched_cat_ids,
matched_centroid_indices: matched_cent_inds,
image_width: config.image_width,
image_height: config.image_height,
cd_matrix: Some(wcs_result.cd_matrix),
crval_rad: Some(wcs_result.crval_rad),
camera_model: Some(result_cam),
theta_rad: Some(wcs_result.theta_rad),
};
}
}
}
SolveResult::failure(status, elapsed_ms(t0))
}
}
/// Build the ordered list of FOV guesses (radians) tried by the solver.
///
/// The first entry is always `fov_estimate`. When `fov_max_error` is `Some(e)` with
/// `e > 0`, the estimate is followed by `fov_estimate + k*step` and
/// `fov_estimate - k*step` for k = 1, 2, ... while `k*step <= e`. Lower candidates
/// that are not strictly positive are skipped.
///
/// `step` is `2 * match_radius * fov_estimate`, floored at 0.001 degrees. A
/// non-finite `fov_max_error` makes the sweep unbounded, so callers should pass a
/// finite value.
fn build_fov_sweep(fov_estimate: f32, fov_max_error: Option<f32>, match_radius: f32) -> Vec<f32> {
    let mut values = vec![fov_estimate];
    let max_error = match fov_max_error {
        Some(e) if e > 0.0 => e,
        _ => return values,
    };
    let step = (2.0 * match_radius * fov_estimate).max(0.001_f32.to_radians());
    // Derive each offset from an integer step count instead of accumulating
    // `offset += step`. Accumulation drifts, and once `offset` is roughly 2^24 times
    // larger than `step` the addition stops changing it and the loop never ends.
    let mut k: u32 = 1;
    loop {
        let offset = k as f32 * step;
        if offset > max_error {
            break;
        }
        values.push(fov_estimate + offset);
        if fov_estimate - offset > 0.0 {
            values.push(fov_estimate - offset);
        }
        k += 1;
    }
    values
}
/// Milliseconds elapsed since `t0`, as a fractional `f32`.
fn elapsed_ms(t0: Instant) -> f32 {
    let seconds = t0.elapsed().as_secs_f32();
    seconds * 1000.0
}
/// Angular spacing between stars when `stars_per_fov` stars share a disc of diameter
/// `fov_rad`: the square root of the per-star area `pi * (fov_rad / 2)^2 / stars_per_fov`.
fn separation_for_density(fov_rad: f32, stars_per_fov: u32) -> f32 {
    let half_fov = fov_rad * 0.5;
    let area_factor = std::f32::consts::PI / stars_per_fov as f32;
    half_fov * area_factor.sqrt()
}
/// Binomial coefficient C(n, k), saturating at `usize::MAX` instead of overflowing.
///
/// Returns 0 when `k > n`. The original running product `result * (n - i)`
/// overflowed (panicking in debug builds) long before the coefficient itself did.
/// Here the product is formed in `u128`, and the smaller of `k` and `n - k` is used.
fn n_choose_k(n: usize, k: usize) -> usize {
    if k > n {
        return 0;
    }
    // C(n, k) == C(n, n - k); fewer steps and smaller intermediates.
    let k = k.min(n - k);
    let mut result: u128 = 1;
    for i in 0..k {
        // Invariant: `result` == C(n, i). C(n, i) * (n - i) == C(n, i + 1) * (i + 1),
        // so the division is exact.
        result = match result.checked_mul((n - i) as u128) {
            Some(product) => product / (i as u128 + 1),
            None => return usize::MAX,
        };
        // C(n, i) grows monotonically for i <= n / 2, so once an intermediate
        // exceeds the range the final value does too.
        if result > usize::MAX as u128 {
            return usize::MAX;
        }
    }
    result as usize
}
/// Append every pattern key in the inclusive box `key_min..=key_max` to `out`.
///
/// The box is taken per ratio dimension, and each key is paired with its squared
/// bin distance from `center`. Keys are appended in lexicographic order, with the
/// first dimension varying slowest. `out` is not cleared first.
fn enumerate_key_range(
    key_min: &[u32; NUM_EDGE_RATIOS],
    key_max: &[u32; NUM_EDGE_RATIOS],
    center: &[u32; NUM_EDGE_RATIOS],
    out: &mut Vec<(u32, [u32; NUM_EDGE_RATIOS])>,
) {
    // Scratch key, filled in one dimension at a time by the recursive walker.
    let mut scratch: [u32; NUM_EDGE_RATIOS] = [0; NUM_EDGE_RATIOS];
    let first_dim = 0;
    enumerate_key_range_recursive(key_min, key_max, center, first_dim, &mut scratch, out);
}
/// Depth-first walker for `enumerate_key_range`.
///
/// Dimensions `0..dim` of `current` are already fixed. Each value in
/// `key_min[dim]..=key_max[dim]` is tried for dimension `dim`, and the walk recurses.
/// Once every dimension is fixed, `(squared distance from center, key)` is pushed
/// to `out`.
fn enumerate_key_range_recursive(
    key_min: &[u32; NUM_EDGE_RATIOS],
    key_max: &[u32; NUM_EDGE_RATIOS],
    center: &[u32; NUM_EDGE_RATIOS],
    dim: usize,
    current: &mut [u32; NUM_EDGE_RATIOS],
    out: &mut Vec<(u32, [u32; NUM_EDGE_RATIOS])>,
) {
    if dim != NUM_EDGE_RATIOS {
        for value in key_min[dim]..=key_max[dim] {
            current[dim] = value;
            enumerate_key_range_recursive(key_min, key_max, center, dim + 1, current, out);
        }
        return;
    }
    // Leaf: squared Euclidean distance in bin space from the center key.
    let dist_sq: u32 = current
        .iter()
        .zip(center.iter())
        .map(|(&c, &m)| {
            let delta = c as i32 - m as i32;
            (delta * delta) as u32
        })
        .sum();
    out.push((dist_sq, *current));
}
/// Reorder `order` (indices into `vectors`) by increasing squared distance from the
/// mean of the four vectors.
///
/// The sort is stable, so equally distant vectors keep their relative order. NaN
/// distances compare as equal.
fn sort_by_centroid_distance_inline(order: &mut [usize; 4], vectors: &[[f32; 3]; 4]) {
    let mut mean = [0.0f32; 3];
    for v in vectors {
        for (acc, &component) in mean.iter_mut().zip(v.iter()) {
            *acc += component;
        }
    }
    for acc in &mut mean {
        *acc /= 4.0;
    }
    let dist_sq = |v: &[f32; 3]| {
        (v[0] - mean[0]).powi(2) + (v[1] - mean[1]).powi(2) + (v[2] - mean[2]).powi(2)
    };
    order.sort_by(|&a, &b| {
        dist_sq(&vectors[a])
            .partial_cmp(&dist_sq(&vectors[b]))
            .unwrap_or(std::cmp::Ordering::Equal)
    });
}
/// Least-squares orthogonal matrix relating paired catalog and image unit vectors.
///
/// Builds the cross-covariance `H = sum(img_i * cat_i^T)` in `f64`, takes its SVD
/// `H = U * S * V^T`, and returns `U * V^T` narrowed to `f32`. No determinant
/// correction is applied, so a reflection (determinant < 0) can be returned. Callers
/// use that sign to detect a parity flip.
fn find_rotation_matrix<const N: usize>(
    image_vectors: &[[f32; 3]; N],
    catalog_vectors: &[[f32; 3]; N],
) -> Matrix3<f32> {
    let widen = |v: &[f32; 3]| {
        nalgebra::Vector3::<f64>::new(v[0] as f64, v[1] as f64, v[2] as f64)
    };
    let cross_covariance = image_vectors.iter().zip(catalog_vectors.iter()).fold(
        nalgebra::Matrix3::<f64>::zeros(),
        |acc, (img, cat)| acc + widen(img) * widen(cat).transpose(),
    );
    let decomposition = cross_covariance.svd(true, true);
    let (u, v_t) = (decomposition.u.unwrap(), decomposition.v_t.unwrap());
    (u * v_t).cast::<f32>()
}
/// Greedily pair image centroids with projected catalog stars.
///
/// Centroid vectors are projected onto the z = 1 plane. Vectors with z <= 0 are
/// parked at `f32::MAX` so they never match. Every centroid/catalog pair whose planar
/// squared distance is strictly below `match_radius^2` is a candidate. Candidates are
/// accepted closest-first, skipping any whose centroid or catalog entry is already used.
///
/// `catalog_positions` holds `(catalog_index, x, y)` on the same plane. Returns
/// `(centroid_index, catalog_index)` pairs in acceptance order.
fn find_centroid_matches(
    centroid_vectors: &[[f32; 3]],
    catalog_positions: &[(usize, f32, f32)],
    match_radius: f32,
) -> Vec<(usize, usize)> {
    let radius_sq = match_radius * match_radius;
    let catalog = catalog_positions;

    let mut pairs: Vec<(f32, usize, usize)> = centroid_vectors
        .iter()
        .map(|v| {
            if v[2] > 0.0 {
                (v[0] / v[2], v[1] / v[2])
            } else {
                (f32::MAX, f32::MAX)
            }
        })
        .enumerate()
        .flat_map(move |(ci, (cx, cy))| {
            catalog
                .iter()
                .enumerate()
                .filter_map(move |(pi, &(_, px, py))| {
                    let (dx, dy) = (cx - px, cy - py);
                    let d2 = dx * dx + dy * dy;
                    (d2 < radius_sq).then_some((d2, ci, pi))
                })
        })
        .collect();
    pairs.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));

    let mut centroid_used = vec![false; centroid_vectors.len()];
    let mut catalog_used = vec![false; catalog.len()];
    pairs
        .into_iter()
        .filter_map(|(_, ci, pi)| {
            if centroid_used[ci] || catalog_used[pi] {
                return None;
            }
            centroid_used[ci] = true;
            catalog_used[pi] = true;
            Some((ci, catalog[pi].0))
        })
        .collect()
}
/// Binomial cumulative distribution `P(X <= k)` for `X ~ Binomial(n, p)`, capped at 1.
///
/// Returns 1 when `k >= n` or `p <= 0`, and 0 when `p >= 1` (with `k < n`).
/// Otherwise the terms are summed from `i = 0` upward. Each term's logarithm is
/// updated incrementally via the ratio `(n - i + 1) / i * p / q` to avoid large
/// factorials.
fn binomial_cdf(k: u32, n: u32, p: f64) -> f64 {
    if k >= n || p <= 0.0 {
        return 1.0;
    }
    if p >= 1.0 {
        return 0.0;
    }
    let q = 1.0 - p;
    let (ln_p, ln_q) = (p.ln(), q.ln());
    let log_first = f64::from(n) * ln_q;
    let (_, total) = (1..=u64::from(k)).fold(
        (log_first, log_first.exp()),
        |(log_pmf, acc), i| {
            let next = log_pmf
                + (((u64::from(n) - i + 1) as f64).ln() - (i as f64).ln() + ln_p - ln_q);
            (next, acc + next.exp())
        },
    );
    total.min(1.0)
}
#[cfg(test)]
mod tests {
// Unit tests for the first-order aberration helper `aberration_correct`.
use super::*;
// A velocity perpendicular to the star direction must keep the result unit length
// and tilt the star toward the velocity by about |beta| = 30 km/s / c radians.
#[test]
fn test_aberration_correct_shift_direction() {
let star = [1.0f32, 0.0, 0.0];
let beta = [0.0, 30.0 / C_KM_S, 0.0];
let apparent = aberration_correct(&star, &beta);
let norm = (apparent[0] as f64 * apparent[0] as f64
+ apparent[1] as f64 * apparent[1] as f64
+ apparent[2] as f64 * apparent[2] as f64)
.sqrt();
assert!((norm - 1.0).abs() < 1e-6, "output not unit length: {norm}");
assert!(apparent[1] > 0.0, "expected positive Y shift, got {}", apparent[1]);
let shift_rad = (apparent[1] as f64).atan2(apparent[0] as f64);
let expected = 30.0 / C_KM_S; assert!(
(shift_rad - expected).abs() < 1e-6,
"shift {shift_rad:.2e} rad, expected ~{expected:.2e} rad"
);
}
// Zero velocity must leave the star vector unchanged (within f32 rounding).
#[test]
fn test_aberration_correct_zero_velocity() {
let s = 1.0f32 / 3.0f32.sqrt();
let star = [s, s, s];
let beta = [0.0, 0.0, 0.0];
let apparent = aberration_correct(&star, &beta);
for i in 0..3 {
assert!(
(apparent[i] - star[i]).abs() < 1e-6,
"component {i} changed: {} -> {}",
star[i],
apparent[i]
);
}
}
// A velocity parallel to the star direction must not move it: the
// `u * (u . beta)` term cancels the `beta` term.
#[test]
fn test_aberration_correct_parallel_velocity() {
let star = [1.0f32, 0.0, 0.0];
let beta = [30.0 / C_KM_S, 0.0, 0.0];
let apparent = aberration_correct(&star, &beta);
assert!(apparent[1].abs() < 1e-7, "Y not zero: {}", apparent[1]);
assert!(apparent[2].abs() < 1e-7, "Z not zero: {}", apparent[2]);
assert!((apparent[0] - 1.0).abs() < 1e-6);
}
}