1use crate::beam::{Beam, BeamError, deconvolve_deg};
8use thiserror::Error;
9
/// Multiplicative factor converting an angle in degrees to radians (π / 180).
const DEG2RAD: f64 = std::f64::consts::PI / 180.0;
11
/// Errors produced while searching for a common beam.
#[derive(Debug, Error)]
pub enum CommonBeamError {
    /// The input slice of beams was empty.
    ///
    /// NOTE: the variant name is a misspelling of "NoBeams"; it is kept as-is
    /// because renaming a public variant would break callers that match on it.
    #[error("no beams provided")]
    NoBeans,
    /// Every beam is invalid or flagged (per the error message).
    #[error("all beams are invalid or flagged")]
    AllFlagged,
    /// The Khachiyan minimum-volume-ellipse iteration hit its iteration cap;
    /// the payload is that cap.
    #[error("Khachiyan algorithm did not converge after {0} iterations")]
    NoConvergence(usize),
    /// No candidate beam could be validated, or a linear-algebra step of the
    /// ellipse search failed (e.g. a singular matrix); the payload describes why.
    #[error("common beam does not deconvolve all inputs: {0}")]
    DeconvFailed(String),
    /// An error raised while constructing a [`Beam`].
    #[error("beam error: {0}")]
    Beam(#[from] BeamError),
}
25
26pub fn common_beam(
50 beams: &[Beam],
51 tolerance: f64,
52 nsamps: usize,
53 epsilon: f64,
54) -> Result<Beam, CommonBeamError> {
55 if beams.is_empty() {
56 return Err(CommonBeamError::NoBeans);
57 }
58 if beams.len() == 1 {
59 return Ok(beams[0]);
60 }
61
62 let largest = largest_beam(beams);
64 if fits_in_beam(beams, &largest) {
65 return Ok(largest);
66 }
67
68 if beams.len() == 2
69 && let Ok(b) = find_commonbeam_between(&beams[0], &beams[1])
70 {
71 return Ok(b);
72 }
73
74 common_manybeams_mve(beams, tolerance, nsamps, epsilon)
75}
76
77pub fn find_commonbeam_between(beam1: &Beam, beam2: &Beam) -> Result<Beam, CommonBeamError> {
81 if beam1.approx_eq(beam2) {
82 return Ok(*beam1);
83 }
84
85 let (large_beam, small_beam) = if beam1.area_sr() >= beam2.area_sr() {
86 (beam1, beam2)
87 } else {
88 (beam2, beam1)
89 };
90
91 let deconv = large_beam.deconvolve_or_zero(small_beam);
93 if deconv.is_finite() {
94 return Ok(*large_beam);
95 }
96
97 let large_major = large_beam.major_arcsec();
98 let large_minor = large_beam.minor_arcsec();
99 let small_major = small_beam.major_arcsec();
100
101 if small_beam.is_circular(1e-6) {
103 let beam = Beam::from_arcsec(large_major, small_major, large_beam.pa_deg)?;
104 return Ok(beam);
105 }
106
107 let pa_diff_rad = ((small_beam.pa_deg - large_beam.pa_deg) * DEG2RAD
109 + std::f64::consts::FRAC_PI_2
110 + std::f64::consts::PI)
111 .rem_euclid(std::f64::consts::PI)
112 - std::f64::consts::FRAC_PI_2;
113
114 if (pa_diff_rad.abs() - std::f64::consts::FRAC_PI_2).abs() < 1e-9 {
116 let (major, minor) = if large_major >= small_major {
117 (large_major, small_major)
118 } else {
119 (small_major, large_major)
120 };
121 let pa = if large_major >= small_major {
122 large_beam.pa_deg
123 } else {
124 small_beam.pa_deg
125 };
126 return Beam::from_arcsec(major, minor, pa).map_err(Into::into);
127 }
128
129 let major_comb = (large_major * small_major).sqrt();
131 let p = major_comb / large_major;
132 let q = major_comb / large_minor;
133
134 let (trans_maj_sc, _trans_min_sc, trans_pa_sc) =
135 transform_ellipse_arcsec(small_major, small_beam.minor_arcsec(), pa_diff_rad, p, q);
136
137 let trans_min_sc = major_comb;
139
140 let (trans_maj_unsc, trans_min_unsc, trans_pa_unsc) =
142 transform_ellipse_arcsec(trans_maj_sc, trans_min_sc, trans_pa_sc, 1.0 / p, 1.0 / q);
143
144 let final_pa_deg = trans_pa_unsc.to_degrees() + large_beam.pa_deg;
145
146 let eps = 100.0 * f64::EPSILON;
147 let beam = Beam::from_arcsec(trans_maj_unsc + eps, trans_min_unsc + eps, final_pa_deg)?;
148
149 Ok(beam)
150}
151
/// Applies an axis-aligned scaling to a rotated ellipse and returns the
/// transformed ellipse as `(major, minor, pa)`.
///
/// `major` and `minor` are axis lengths. Callers pass arcseconds, but any
/// linear unit works because the map is linear. `pa` is the ellipse rotation
/// in radians. The transform multiplies x by `x_scale` and y by `y_scale`.
///
/// The returned axes satisfy `major >= minor`. The returned angle is in
/// radians: it is `0` for a circle, and otherwise a magnitude in `[0, pi/2]`
/// carrying the sign of the input `pa`.
fn transform_ellipse_arcsec(
    major: f64,
    minor: f64,
    pa: f64,
    x_scale: f64,
    y_scale: f64,
) -> (f64, f64, f64) {
    let cospa = pa.cos();
    let sinpa = pa.sin();
    let cos2pa = cospa * cospa;
    let sin2pa = sinpa * sinpa;
    let major2 = major * major;
    let minor2 = minor * minor;

    // Implicit form of the rotated ellipse: a x^2 + b x y + c y^2 = 1.
    let a = cos2pa / major2 + sin2pa / minor2;
    let b = -2.0 * cospa * sinpa * (1.0 / major2 - 1.0 / minor2);
    let c = sin2pa / major2 + cos2pa / minor2;

    let x2 = x_scale * x_scale;
    let y2 = y_scale * y_scale;

    // Scaled ellipse: r x^2 + B x y + t y^2 = 1, with s = (B / 2)^2.
    let r = a / x2;
    let s = b * b / (4.0 * x2 * y2);
    let t = c / y2;

    let pa_sign = if pa >= 0.0 { 1.0 } else { -1.0 };

    let udiff = r - t;
    let f1 = udiff * udiff + 4.0 * s;

    // f1 == 0 means r == t with no cross term: the scaled ellipse is a circle
    // of radius 1/sqrt(r). The general formulas below would compute 0/0 here.
    if f1 <= 0.0 {
        let radius = 1.0 / r.sqrt();
        return (radius, radius, 0.0);
    }

    // r == t with a cross term: the axes lie exactly at 45 degrees, where the
    // `2 * j - 1` denominators below vanish. Use the eigenvalues
    // (r + t)/2 -/+ sqrt(s) directly; the smaller one gives the major axis.
    if udiff == 0.0 {
        let mean = 0.5 * (r + t);
        let half_split = s.sqrt();
        return (
            1.0 / (mean - half_split).sqrt(),
            1.0 / (mean + half_split).sqrt(),
            pa_sign * std::f64::consts::FRAC_PI_4,
        );
    }

    let f2 = f1.sqrt() * udiff.abs();

    // j1 and j2 are cos^2 of the candidate axis angles; k1 and k2 are the
    // matching eigenvalues, so 1/sqrt(k) is an axis length.
    let j1 = (f2 + f1) / (2.0 * f1);
    let j2 = (f1 - f2) / (2.0 * f1);

    let k1 = (j1 * (r + t) - t) / (2.0 * j1 - 1.0);
    let k2 = (j2 * (r + t) - t) / (2.0 * j2 - 1.0);

    let c1 = 1.0 / k1.sqrt();
    let c2 = 1.0 / k2.sqrt();

    if (c1 - c2).abs() < f64::EPSILON {
        // Effectively circular: report the axis lengths themselves, not their
        // reciprocals.
        (c1.max(c2), c1.min(c2), 0.0)
    } else if c1 > c2 {
        (c1, c2, pa_sign * j1.sqrt().acos())
    } else {
        (c2, c1, pa_sign * j2.sqrt().acos())
    }
}
202
203pub fn common_manybeams_mve(
208 beams: &[Beam],
209 tolerance: f64,
210 nsamps: usize,
211 epsilon: f64,
212) -> Result<Beam, CommonBeamError> {
213 let max_iter = 10;
214 let max_epsilon = 1e-3_f64;
215 let mut eps = epsilon;
216
217 for step in 0..=max_iter {
218 let all_pts = collect_ellipse_points(beams, nsamps, eps);
219 let hull_pts = convex_hull_2d(&all_pts);
220
221 let (radii, rotation) = min_vol_ellipse(&hull_pts, tolerance)?;
222
223 let pa = (-rotation[0][0]).atan2(rotation[1][0]);
226 let pa = if pa == -std::f64::consts::PI || pa == std::f64::consts::PI {
227 0.0
228 } else {
229 pa
230 };
231
232 let r0 = radii[0];
233 let r1 = radii[1];
234 let (major_deg, minor_deg) = if r0 >= r1 { (r0, r1) } else { (r1, r0) };
235
236 let com_beam = Beam::new(major_deg, minor_deg, pa.to_degrees())?;
237
238 if fits_in_beam(beams, &com_beam) {
239 return Ok(com_beam);
240 }
241
242 if step == max_iter {
243 return Err(CommonBeamError::DeconvFailed(format!(
244 "epsilon reached {eps:.2e} without finding valid solution"
245 )));
246 }
247
248 eps += (step as f64 + 1.0) * (max_epsilon - eps) / max_iter as f64;
249 }
250
251 unreachable!()
252}
253
254fn collect_ellipse_points(beams: &[Beam], nsamps: usize, epsilon: f64) -> Vec<[f64; 2]> {
256 let mut pts = Vec::with_capacity(beams.len() * nsamps);
257 for beam in beams {
258 let bpa = beam.pa_deg * DEG2RAD;
259 let major = beam.major_deg * (1.0 + epsilon);
260 let minor = beam.minor_deg * (1.0 + epsilon);
261 for k in 0..nsamps {
262 let phi = 2.0 * std::f64::consts::PI * k as f64 / nsamps as f64;
263 let x = major * phi.cos();
264 let y = minor * phi.sin();
265 let xr = x * bpa.cos() - y * bpa.sin();
266 let yr = x * bpa.sin() + y * bpa.cos();
267 pts.push([xr, yr]);
268 }
269 }
270 pts
271}
272
273fn convex_hull_2d(pts: &[[f64; 2]]) -> Vec<[f64; 2]> {
275 if pts.len() <= 3 {
276 return pts.to_vec();
277 }
278
279 let pivot = pts
281 .iter()
282 .enumerate()
283 .min_by(|(_, a), (_, b)| {
284 a[1].partial_cmp(&b[1])
285 .unwrap()
286 .then(a[0].partial_cmp(&b[0]).unwrap())
287 })
288 .map(|(i, _)| i)
289 .unwrap();
290
291 let pivot_pt = pts[pivot];
292
293 let mut sorted: Vec<[f64; 2]> = pts.iter().filter(|&&p| p != pivot_pt).cloned().collect();
294
295 sorted.sort_by(|a, b| {
296 let angle_a = (a[1] - pivot_pt[1]).atan2(a[0] - pivot_pt[0]);
297 let angle_b = (b[1] - pivot_pt[1]).atan2(b[0] - pivot_pt[0]);
298 angle_a.partial_cmp(&angle_b).unwrap()
299 });
300
301 let mut hull: Vec<[f64; 2]> = vec![pivot_pt];
302 for &p in &sorted {
303 while hull.len() > 1 {
304 let n = hull.len();
305 let cross = cross2d(hull[n - 2], hull[n - 1], p);
306 if cross <= 0.0 {
307 hull.pop();
308 } else {
309 break;
310 }
311 }
312 hull.push(p);
313 }
314 hull
315}
316
/// Z component of the cross product `(a - o) x (b - o)`.
///
/// Positive when `o -> a -> b` turns counter-clockwise, negative for a
/// clockwise turn, and zero when the three points are collinear.
fn cross2d(o: [f64; 2], a: [f64; 2], b: [f64; 2]) -> f64 {
    let (ax, ay) = (a[0] - o[0], a[1] - o[1]);
    let (bx, by) = (b[0] - o[0], b[1] - o[1]);
    ax * by - ay * bx
}
320
321fn min_vol_ellipse(
328 pts: &[[f64; 2]],
329 tolerance: f64,
330) -> Result<([f64; 2], [[f64; 2]; 2]), CommonBeamError> {
331 let n = pts.len();
332 let d = 2_usize;
333
334 let q: Vec<[f64; 3]> = pts.iter().map(|p| [p[0], p[1], 1.0]).collect(); let mut u = vec![1.0 / n as f64; n];
338
339 let max_iter = 100_000;
340 let mut err = 1.0_f64;
341 let mut iter = 0;
342
343 while err > tolerance {
344 let v = matmul_qt_diag_q(&q, &u); let v_inv = mat3_inv(v)?;
347
348 let m: Vec<f64> = q.iter().map(|qi| quadratic_form_3(qi, &v_inv)).collect();
350
351 let j = m
352 .iter()
353 .enumerate()
354 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
355 .map(|(i, _)| i)
356 .unwrap();
357 let maximum = m[j];
358
359 let step = (maximum - d as f64 - 1.0) / ((d as f64 + 1.0) * (maximum - 1.0));
360 let new_u: Vec<f64> = u.iter().map(|&ui| (1.0 - step) * ui).collect();
361 err = new_u
362 .iter()
363 .zip(u.iter())
364 .map(|(a, b)| (a - b).powi(2))
365 .sum::<f64>()
366 .sqrt();
367 u = new_u;
368 u[j] += step;
369
370 iter += 1;
371 if iter >= max_iter {
372 return Err(CommonBeamError::NoConvergence(max_iter));
373 }
374 }
375
376 let center = [
378 pts.iter()
379 .zip(u.iter())
380 .map(|(p, &ui)| p[0] * ui)
381 .sum::<f64>(),
382 pts.iter()
383 .zip(u.iter())
384 .map(|(p, &ui)| p[1] * ui)
385 .sum::<f64>(),
386 ];
387
388 let ptdp = matmul_pt_diag_p(pts, &u); let cc = [
391 [center[0] * center[0], center[0] * center[1]],
392 [center[1] * center[0], center[1] * center[1]],
393 ];
394 let inner = [
395 [ptdp[0][0] - cc[0][0], ptdp[0][1] - cc[0][1]],
396 [ptdp[1][0] - cc[1][0], ptdp[1][1] - cc[1][1]],
397 ];
398 let a = mat2_scale(mat2_inv(inner)?, 1.0 / d as f64);
399
400 let (eigenvalues, eigenvectors) = symmetric_2x2_eig(a);
403
404 let radii = [
405 (1.0 / eigenvalues[0].abs().sqrt()) * (1.0 + tolerance),
406 (1.0 / eigenvalues[1].abs().sqrt()) * (1.0 + tolerance),
407 ];
408
409 Ok((radii, eigenvectors))
410}
411
/// Computes the 3x3 weighted moment matrix `sum_i u[i] * q[i] q[i]^T`.
///
/// This equals `Q diag(u) Q^T` when the `q[i]` are the columns of `Q`. Pairs
/// beyond the shorter of `q` and `u` are ignored.
fn matmul_qt_diag_q(q: &[[f64; 3]], u: &[f64]) -> [[f64; 3]; 3] {
    q.iter().zip(u).fold([[0.0_f64; 3]; 3], |mut acc, (row, &w)| {
        for (i, acc_row) in acc.iter_mut().enumerate() {
            for (k, cell) in acc_row.iter_mut().enumerate() {
                *cell += row[i] * w * row[k];
            }
        }
        acc
    })
}
426
/// Computes the 2x2 weighted second-moment matrix `sum_i u[i] * p[i] p[i]^T`.
///
/// This equals `P^T diag(u) P` with the points as rows of `P`. Pairs beyond the
/// shorter of `p` and `u` are ignored.
fn matmul_pt_diag_p(p: &[[f64; 2]], u: &[f64]) -> [[f64; 2]; 2] {
    let mut acc = [[0.0_f64; 2]; 2];
    for (point, &weight) in p.iter().zip(u) {
        for (i, row) in acc.iter_mut().enumerate() {
            for (k, cell) in row.iter_mut().enumerate() {
                *cell += point[i] * weight * point[k];
            }
        }
    }
    acc
}
439
/// Evaluates the quadratic form `x^T m x` for a 3-vector and a 3x3 matrix.
///
/// Terms are summed row by row, left to right.
fn quadratic_form_3(x: &[f64; 3], m: &[[f64; 3]; 3]) -> f64 {
    m.iter().zip(x).fold(0.0_f64, |acc, (row, &xr)| {
        row.iter()
            .zip(x)
            .fold(acc, |sum, (&mrc, &xc)| sum + xr * mrc * xc)
    })
}
450
451fn mat3_inv(m: [[f64; 3]; 3]) -> Result<[[f64; 3]; 3], CommonBeamError> {
453 let det = m[0][0] * (m[1][1] * m[2][2] - m[1][2] * m[2][1])
454 - m[0][1] * (m[1][0] * m[2][2] - m[1][2] * m[2][0])
455 + m[0][2] * (m[1][0] * m[2][1] - m[1][1] * m[2][0]);
456
457 if det.abs() < f64::EPSILON {
458 return Err(CommonBeamError::DeconvFailed("singular 3x3 matrix".into()));
459 }
460 let inv_det = 1.0 / det;
461
462 Ok([
463 [
464 inv_det * (m[1][1] * m[2][2] - m[1][2] * m[2][1]),
465 inv_det * (m[0][2] * m[2][1] - m[0][1] * m[2][2]),
466 inv_det * (m[0][1] * m[1][2] - m[0][2] * m[1][1]),
467 ],
468 [
469 inv_det * (m[1][2] * m[2][0] - m[1][0] * m[2][2]),
470 inv_det * (m[0][0] * m[2][2] - m[0][2] * m[2][0]),
471 inv_det * (m[0][2] * m[1][0] - m[0][0] * m[1][2]),
472 ],
473 [
474 inv_det * (m[1][0] * m[2][1] - m[1][1] * m[2][0]),
475 inv_det * (m[0][1] * m[2][0] - m[0][0] * m[2][1]),
476 inv_det * (m[0][0] * m[1][1] - m[0][1] * m[1][0]),
477 ],
478 ])
479}
480
481fn mat2_inv(m: [[f64; 2]; 2]) -> Result<[[f64; 2]; 2], CommonBeamError> {
483 let det = m[0][0] * m[1][1] - m[0][1] * m[1][0];
484 if det.abs() < f64::EPSILON {
485 return Err(CommonBeamError::DeconvFailed("singular 2x2 matrix".into()));
486 }
487 let inv_det = 1.0 / det;
488 Ok([
489 [inv_det * m[1][1], -inv_det * m[0][1]],
490 [-inv_det * m[1][0], inv_det * m[0][0]],
491 ])
492}
493
/// Multiplies every entry of a 2x2 matrix by the scalar `s`.
fn mat2_scale(m: [[f64; 2]; 2], s: f64) -> [[f64; 2]; 2] {
    m.map(|row| row.map(|entry| entry * s))
}
497
498fn symmetric_2x2_eig(m: [[f64; 2]; 2]) -> ([f64; 2], [[f64; 2]; 2]) {
501 let a = m[0][0];
502 let b = m[0][1]; let c = m[1][1];
504
505 let trace = a + c;
506 let det = a * c - b * b;
507 let disc = ((trace / 2.0).powi(2) - det).max(0.0).sqrt();
508
509 let lam1 = trace / 2.0 + disc;
510 let lam2 = trace / 2.0 - disc;
511
512 let (v1, v2) = if b.abs() > f64::EPSILON {
514 let v1 = normalise([lam1 - c, b]);
515 let v2 = normalise([lam2 - c, b]);
516 (v1, v2)
517 } else if a >= c {
518 ([1.0, 0.0], [0.0, 1.0])
519 } else {
520 ([0.0, 1.0], [1.0, 0.0])
521 };
522
523 let rotation = [[v1[0], v2[0]], [v1[1], v2[1]]];
528
529 ([lam1, lam2], rotation)
530}
531
/// Scales a 2-vector to unit Euclidean length.
///
/// Vectors shorter than `f64::EPSILON` are returned unchanged rather than
/// divided by a (near-)zero length.
fn normalise(v: [f64; 2]) -> [f64; 2] {
    let [x, y] = v;
    let norm = (x * x + y * y).sqrt();
    if norm < f64::EPSILON {
        v
    } else {
        [x / norm, y / norm]
    }
}
539
540pub fn largest_beam(beams: &[Beam]) -> Beam {
543 beams
544 .iter()
545 .max_by(|a, b| a.area_sr().partial_cmp(&b.area_sr()).unwrap())
546 .copied()
547 .unwrap()
548}
549
550pub fn fits_in_beam(beams: &[Beam], large_beam: &Beam) -> bool {
552 beams.iter().all(|b| {
553 if b.approx_eq(large_beam) {
554 return true;
555 }
556 let result = deconvolve_deg(
557 large_beam.major_deg,
558 large_beam.minor_deg,
559 large_beam.pa_deg,
560 b.major_deg,
561 b.minor_deg,
562 b.pa_deg,
563 true,
564 );
565 match result {
566 Ok((maj, min, _)) => maj > 0.0 && min > 0.0,
567 Err(_) => false,
568 }
569 })
570}
571
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a beam from axes given in arcseconds and a PA in degrees.
    fn beam_arcsec(major: f64, minor: f64, pa: f64) -> Beam {
        Beam::new(major / 3600.0, minor / 3600.0, pa).unwrap()
    }

    /// Two identical beams must yield that beam (or something at least as large).
    #[test]
    fn test_common_beam_identical() {
        let beam = beam_arcsec(10.0, 8.0, 30.0);
        let common = common_beam(&[beam, beam], 1e-4, 200, 5e-4).unwrap();
        assert!(common.approx_eq(&beam) || common.major_deg >= beam.major_deg);
    }

    /// Two differently shaped beams must yield a beam that covers both.
    #[test]
    fn test_common_beam_two_different() {
        let first = beam_arcsec(10.0, 8.0, 30.0);
        let second = beam_arcsec(12.0, 6.0, 60.0);
        let common = common_beam(&[first, second], 1e-4, 200, 5e-4).unwrap();
        let widest = first.major_deg.max(second.major_deg);
        assert!(common.major_deg >= widest || fits_in_beam(&[first, second], &common));
    }
}