use std::path::{Path, PathBuf};
use std::time::Duration;
use anyhow::{Context, Result, bail};
use clap::{Args, Parser, Subcommand, ValueEnum};
use indicatif::{ProgressBar, ProgressStyle};
use ndarray::Array2;
use rayon::prelude::*;
use tracing::{debug, info, warn};
use convolve_rs::{
beam::Beam,
common_beam::{common_beam, fits_in_beam},
cube_io::{self, CubeMeta, CubeMode},
fits_io::{output_path, read_fits, write_fits},
smooth::smooth,
};
// Top-level command-line interface for `convolvers`.
// Plain `//` comments are used on clap-derived types on purpose: `///` doc comments
// would be picked up by clap and change the generated `--help` output.
#[derive(Parser)]
#[command(
    name = "convolvers",
    about = "Convolve FITS images/cubes to a common beam",
    version
)]
struct Cli {
    // The selected subcommand (`2d` or `3d`), dispatched in `main`.
    #[command(subcommand)]
    command: Commands,
}
// Available subcommands. The `///` docs on each variant are rendered by clap as the
// subcommand's one-line help in `convolvers --help`.
#[derive(Subcommand)]
enum Commands {
    /// Convolve FITS images to a common beam
    #[command(name = "2d")]
    TwoD(TwoDArgs),
    /// Convolve FITS cubes channel-by-channel to common beam(s)
    #[command(name = "3d")]
    ThreeD(ThreeDArgs),
}
fn main() -> Result<()> {
let cli = Cli::parse();
match cli.command {
Commands::TwoD(args) => cmd_2d(args),
Commands::ThreeD(args) => cmd_3d(args),
}
}
// Options shared by the `2d` and `3d` subcommands (flattened into each).
// The field-level `///` docs below are rendered by clap as `--help` text.
#[derive(Args, Debug, Clone)]
struct SharedArgs {
    /// Suffix used when building output file names
    #[arg(short, long, default_value = "sm")]
    suffix: String,
    /// Prefix used when building output file names
    #[arg(short, long)]
    prefix: Option<String>,
    /// Directory to write output files to
    #[arg(short, long)]
    outdir: Option<PathBuf>,
    /// Target beam major axis [arcsec] (requires --bmin and --bpa)
    #[arg(long)]
    bmaj: Option<f64>,
    /// Target beam minor axis [arcsec] (requires --bmaj and --bpa)
    #[arg(long)]
    bmin: Option<f64>,
    /// Target beam position angle [deg] (requires --bmaj and --bmin)
    #[arg(long)]
    bpa: Option<f64>,
    /// Circularise the rounded target beam (BMIN = BMAJ, BPA = 0)
    #[arg(long)]
    circularise: bool,
    /// Blank inputs whose BMAJ exceeds this value [arcsec] and exclude them from the common beam
    #[arg(short, long)]
    cutoff: Option<f64>,
    /// Report the target beam(s) without writing any files
    #[arg(short, long)]
    dryrun: bool,
    /// Tolerance passed to the common-beam solver
    #[arg(long, default_value_t = 1e-4)]
    tolerance: f64,
    /// Number of samples passed to the common-beam solver
    #[arg(long, default_value_t = 200)]
    nsamps: usize,
    /// Epsilon passed to the common-beam solver
    #[arg(long, default_value_t = 5e-4)]
    epsilon: f64,
    /// Increase log verbosity (-v: debug, -vv: trace)
    #[arg(short, long, action = clap::ArgAction::Count)]
    verbose: u8,
}
// Arguments for the `2d` subcommand. Field-level `///` docs become `--help` text.
#[derive(Parser, Debug)]
struct TwoDArgs {
    /// Input FITS image(s), or one text file listing them (with --listfile)
    #[arg(required = true, num_args = 1..)]
    infile: Vec<PathBuf>,
    /// Read input paths (one per line) from the single INFILE
    #[arg(long)]
    listfile: bool,
    #[command(flatten)]
    shared: SharedArgs,
    /// Write a beamlog (old, target and kernel beams in degrees per output file) to this path
    #[arg(long)]
    log: Option<PathBuf>,
}
/// One row of the beamlog that `cmd_2d` writes when `--log` is given.
struct BeamLogEntry2D {
    /// Path of the smoothed output file (not the input file).
    filename: PathBuf,
    /// Beam of the input image before convolution.
    old_beam: Beam,
    /// Target beam the image was convolved to.
    new_beam: Beam,
    /// Kernel beam, as computed by `new_beam.deconvolve_or_zero(&old_beam)`.
    conv_beam: Beam,
}
fn cmd_2d(args: TwoDArgs) -> Result<()> {
init_logging(args.shared.verbose);
let files = collect_files(&args.infile, args.listfile)?;
let target_beam = parse_target_beam(&args.shared)?;
let sp = spinner(format!(
"Reading beam parameters from {} file(s)…",
files.len()
));
let all_beams: Vec<Beam> = files
.iter()
.map(|f| {
let data = read_fits(f).with_context(|| format!("reading {}", f.display()))?;
if let Some(cutoff) = args.shared.cutoff
&& data.beam.major_arcsec() > cutoff
{
sp.suspend(|| {
warn!(
"{}: BMAJ={:.1}\" > cutoff={:.1}\" — will be blanked",
f.display(),
data.beam.major_arcsec(),
cutoff
)
});
}
Ok(data.beam)
})
.collect::<Result<Vec<_>>>()?;
sp.finish_and_clear();
let mut common = match target_beam {
Some(b) => {
if !fits_in_beam(&all_beams, &b) {
bail!("target beam is too small — some images cannot reach it");
}
b
}
None => {
let valid: Vec<Beam> = all_beams
.iter()
.filter(|b| {
b.is_finite()
&& !b.is_zero()
&& args.shared.cutoff.is_none_or(|c| b.major_arcsec() <= c)
})
.cloned()
.collect();
anyhow::ensure!(!valid.is_empty(), "all beams are flagged or invalid");
let sp = spinner("Solving for the common beam…");
let cb = common_beam(
&valid,
args.shared.tolerance,
args.shared.nsamps,
args.shared.epsilon,
)
.context("could not find common beam")?;
sp.finish_and_clear();
cb
}
};
common = apply_beam_rounding(common, args.shared.circularise)?;
info!("Common beam: {common}");
if args.shared.dryrun {
info!("Dry run — no files written.");
return Ok(());
}
let pb = progress_bar(files.len() as u64);
let results: Vec<BeamLogEntry2D> = files
.par_iter()
.zip(all_beams.par_iter())
.map(|(file, old_beam)| {
pb.suspend(|| debug!("Reading {}", file.display()));
let data = read_fits(file).with_context(|| format!("reading {}", file.display()))?;
let out = output_path(
file,
Some(&args.shared.suffix),
args.shared.prefix.as_deref(),
args.shared.outdir.as_deref(),
);
let conv_beam = common.deconvolve_or_zero(old_beam);
pb.suspend(|| {
debug!(
"{}: current {old_beam} | target {common} | kernel {conv_beam}",
file.display()
)
});
let smoothed = smooth(
&data.image,
old_beam,
&common,
data.dx_deg,
data.dy_deg,
args.shared.cutoff,
data.unit,
)
.with_context(|| format!("smoothing {}", file.display()))?;
pb.suspend(|| debug!("Writing {}", out.display()));
write_fits(&smoothed, &out, file, &common, data.is_4d)
.with_context(|| format!("writing {}", out.display()))?;
pb.suspend(|| info!("{} → {}", file.display(), out.display()));
pb.inc(1);
Ok(BeamLogEntry2D {
filename: out,
old_beam: *old_beam,
new_beam: common,
conv_beam,
})
})
.collect::<Result<Vec<_>>>()?;
pb.finish_with_message("done");
if let Some(log_path) = &args.log {
use std::fmt::Write as _;
let mut out = String::from(
"# FileName OldBMAJ[deg] OldBMIN[deg] OldBPA[deg] TargetBMAJ[deg] TargetBMIN[deg] TargetBPA[deg] ConvBMAJ[deg] ConvBMIN[deg] ConvBPA[deg]\n",
);
for e in &results {
writeln!(
out,
"{} {} {} {} {} {} {} {} {} {}",
e.filename.display(),
e.old_beam.major_deg,
e.old_beam.minor_deg,
e.old_beam.pa_deg,
e.new_beam.major_deg,
e.new_beam.minor_deg,
e.new_beam.pa_deg,
e.conv_beam.major_deg,
e.conv_beam.minor_deg,
e.conv_beam.pa_deg,
)?;
}
std::fs::write(log_path, out)?;
info!("Beamlog written to {}", log_path.display());
}
Ok(())
}
// Command-line mirror of `CubeMode`. It selects how `3d` target beams are solved.
// The variant `///` docs are shown by clap as help for each possible value.
#[derive(ValueEnum, Clone, Debug, PartialEq, Eq)]
enum ModeArg {
    /// Solve a separate common beam for each channel
    Natural,
    /// Solve one common beam across all channels of all cubes
    Total,
}
// Arguments for the `3d` subcommand. Field-level `///` docs become `--help` text.
#[derive(Parser, Debug)]
struct ThreeDArgs {
    /// Input FITS cube(s), or one text file listing them (with --listfile)
    #[arg(required = true, num_args = 1..)]
    infile: Vec<PathBuf>,
    /// Read input paths (one per line) from the single INFILE
    #[arg(long)]
    listfile: bool,
    #[command(flatten)]
    shared: SharedArgs,
    /// How target beams are solved when --bmaj/--bmin/--bpa are not given
    #[arg(long, default_value = "natural", value_enum)]
    mode: ModeArg,
}
fn cmd_3d(args: ThreeDArgs) -> Result<()> {
init_logging(args.shared.verbose);
let files = collect_files(&args.infile, args.listfile)?;
let sp = spinner(format!("Reading metadata from {} cube(s)…", files.len()));
let metas: Vec<CubeMeta> = files
.iter()
.map(|f| {
sp.suspend(|| debug!("Reading metadata + per-channel beams from {}", f.display()));
let m = cube_io::read_cube_meta(f)
.with_context(|| format!("reading metadata from {}", f.display()))?;
sp.suspend(|| {
debug!(
"{}: {}×{} px, {} channels, {} Stokes",
f.display(),
m.nx,
m.ny,
m.nfreq,
m.nstokes
)
});
Ok(m)
})
.collect::<Result<_>>()?;
sp.finish_and_clear();
info!("Read metadata from {} cube(s)", files.len());
let nfreq = metas[0].nfreq;
for (f, m) in files.iter().zip(metas.iter()) {
anyhow::ensure!(
m.nfreq == nfreq,
"{}: expected {} channels, got {}",
f.display(),
nfreq,
m.nfreq
);
if m.nstokes > 1 {
warn!(
"{}: NAXIS4={} — only Stokes 0 will be convolved",
f.display(),
m.nstokes
);
}
}
let target_beam = parse_target_beam(&args.shared)?;
let target_beams: Vec<Option<Beam>> = if let Some(b) = target_beam {
let all_valid: Vec<Beam> = metas
.iter()
.flat_map(|m| m.beams.iter())
.filter_map(|b| *b)
.filter(|b| b.is_finite() && !b.is_zero())
.collect();
if !fits_in_beam(&all_valid, &b) {
bail!("target beam is too small — some channels cannot reach it");
}
vec![Some(b); nfreq]
} else {
let mode = match args.mode {
ModeArg::Natural => CubeMode::Natural,
ModeArg::Total => CubeMode::Total,
};
let sp = spinner(match mode {
CubeMode::Natural => "Solving for per-channel common beams…".to_string(),
CubeMode::Total => "Solving for the common beam across all channels…".to_string(),
});
let beams = compute_target_beams(
&metas,
mode,
args.shared.cutoff,
args.shared.circularise,
args.shared.tolerance,
args.shared.nsamps,
args.shared.epsilon,
)?;
sp.finish_and_clear();
beams
};
let n_valid = target_beams.iter().filter(|b| b.is_some()).count();
let all_same = target_beams
.iter()
.filter_map(|b| *b)
.collect::<Vec<_>>()
.windows(2)
.all(|w| w[0] == w[1]);
match target_beams.iter().find_map(|b| *b) {
Some(b) if all_same => info!("Target beam (all channels): {b}"),
Some(b) => {
info!("Target beam varies per channel ({n_valid} valid channels); e.g. channel 0: {b}");
info!("Run with -v to log the current/target/kernel beam for every channel.");
}
None => {}
}
if args.shared.dryrun {
info!("Dry run — no files written.");
return Ok(());
}
let cube_mode = match args.mode {
ModeArg::Natural => CubeMode::Natural,
ModeArg::Total => CubeMode::Total,
};
let pb = progress_bar((files.len() * nfreq) as u64);
for (file, meta) in files.iter().zip(metas.iter()) {
let out = output_path(
file,
Some(&args.shared.suffix),
args.shared.prefix.as_deref(),
args.shared.outdir.as_deref(),
);
pb.suspend(|| info!("Initialising output cube {} …", out.display()));
pb.set_message("initialising");
cube_io::init_output_cube(file, &out, &target_beams, cube_mode, meta)
.with_context(|| format!("initialising output cube {}", out.display()))?;
pb.set_message("processing");
let cap = (rayon::current_num_threads() * 2).max(4);
let (tx, rx) = std::sync::mpsc::sync_channel::<(usize, Array2<f32>)>(cap);
let writer_out = out.clone();
let writer_meta = meta;
let result: Result<()> = std::thread::scope(|s| {
let writer_handle = s.spawn(move || -> Result<()> {
let mut writer = cube_io::CubeWriter::open(&writer_out)
.with_context(|| format!("opening output cube {}", writer_out.display()))?;
for (c, plane) in rx {
writer
.write_channel(c, &plane, writer_meta)
.with_context(|| {
format!("writing channel {c} to {}", writer_out.display())
})?;
}
Ok(())
});
let produce: Result<()> = (0..nfreq).into_par_iter().try_for_each(|c| {
let old_beam = match meta.beams[c] {
Some(b) => b,
None => {
pb.inc(1);
return Ok(());
}
};
let target = match target_beams[c] {
Some(b) => b,
None => {
pb.inc(1);
return Ok(());
}
};
let kernel = target.deconvolve_or_zero(&old_beam);
pb.suspend(|| {
debug!("Channel {c}: current {old_beam} | target {target} | kernel {kernel}")
});
let plane = if let Some(cutoff) = args.shared.cutoff
&& old_beam.major_arcsec() > cutoff
{
pb.suspend(|| {
warn!(
"Channel {c}: BMAJ={:.1}\" > cutoff — blanking",
old_beam.major_arcsec()
)
});
Array2::from_elem((meta.ny, meta.nx), f32::NAN)
} else {
let raw = cube_io::read_channel(file, c, meta)
.with_context(|| format!("reading channel {c} from {}", file.display()))?;
smooth(
&raw,
&old_beam,
&target,
meta.dx_deg,
meta.dy_deg,
args.shared.cutoff,
meta.unit,
)
.with_context(|| format!("smoothing channel {c}"))?
};
tx.send((c, plane))
.map_err(|_| anyhow::anyhow!("writer thread stopped before channel {c}"))?;
pb.inc(1);
Ok(())
});
drop(tx);
let writer_result = writer_handle
.join()
.map_err(|_| anyhow::anyhow!("writer thread panicked"))?;
writer_result.and(produce)
});
result?;
let beamlog = {
let dir = out.parent().unwrap_or(Path::new("."));
let stem = out.file_stem().unwrap_or_default();
dir.join(format!("beamlog.{}.txt", stem.to_string_lossy()))
};
cube_io::write_beamlog(&beamlog, &target_beams)
.with_context(|| format!("writing beamlog {}", beamlog.display()))?;
pb.suspend(|| debug!("Beamlog written to {}", beamlog.display()));
pb.suspend(|| info!("{} → {}", file.display(), out.display()));
}
pb.finish_with_message("done");
Ok(())
}
fn compute_target_beams(
metas: &[CubeMeta],
mode: CubeMode,
cutoff: Option<f64>,
circularise: bool,
tolerance: f64,
nsamps: usize,
epsilon: f64,
) -> Result<Vec<Option<Beam>>> {
let nfreq = metas[0].nfreq;
match mode {
CubeMode::Natural => (0..nfreq)
.map(|c| {
let valid: Vec<Beam> = metas
.iter()
.filter_map(|m| m.beams[c])
.filter(|b| b.is_finite() && !b.is_zero())
.filter(|b| cutoff.is_none_or(|cut| b.major_arcsec() <= cut))
.collect();
if valid.is_empty() {
return Ok(None);
}
let cb = common_beam(&valid, tolerance, nsamps, epsilon)
.with_context(|| format!("finding common beam for channel {c}"))?;
Ok(Some(apply_beam_rounding(cb, circularise)?))
})
.collect(),
CubeMode::Total => {
let valid: Vec<Beam> = metas
.iter()
.flat_map(|m| m.beams.iter())
.filter_map(|b| *b)
.filter(|b| b.is_finite() && !b.is_zero())
.filter(|b| cutoff.is_none_or(|cut| b.major_arcsec() <= cut))
.collect();
anyhow::ensure!(
!valid.is_empty(),
"no valid beams found across all cubes/channels"
);
let cb = common_beam(&valid, tolerance, nsamps, epsilon)
.context("finding total common beam")?;
let cb = apply_beam_rounding(cb, circularise)?;
Ok(vec![Some(cb); nfreq])
}
}
}
/// Install the global `tracing` subscriber, with verbosity chosen by the `-v` count.
///
/// `0` → INFO, `1` → DEBUG, `2` or more → TRACE. Event targets are not printed.
/// Each subcommand calls this once at startup.
fn init_logging(verbose: u8) {
    let max_level = if verbose >= 2 {
        tracing::Level::TRACE
    } else if verbose == 1 {
        tracing::Level::DEBUG
    } else {
        tracing::Level::INFO
    };
    let builder = tracing_subscriber::fmt().with_max_level(max_level);
    builder.with_target(false).init();
}
fn collect_files(infile: &[PathBuf], listfile: bool) -> Result<Vec<PathBuf>> {
let files = if listfile {
anyhow::ensure!(infile.len() == 1, "only one listfile argument supported");
std::fs::read_to_string(&infile[0])?
.lines()
.map(|l| PathBuf::from(l.trim()))
.collect()
} else {
infile.to_vec()
};
anyhow::ensure!(!files.is_empty(), "no input files found");
Ok(files)
}
/// Build the optional user-specified target beam from `--bmaj`, `--bmin` and `--bpa`.
///
/// Returns `Ok(None)` when none of the three is given. Major and minor are in arcsec
/// and PA in degrees, as interpreted by `Beam::from_arcsec`.
///
/// # Errors
/// Fails if only some of the three options are given, or if `Beam::from_arcsec`
/// rejects the values.
fn parse_target_beam(args: &SharedArgs) -> Result<Option<Beam>> {
    let (major, minor, pa) = match (args.bmaj, args.bmin, args.bpa) {
        (Some(major), Some(minor), Some(pa)) => (major, minor, pa),
        (None, None, None) => return Ok(None),
        _ => bail!("--bmaj, --bmin, and --bpa must all be specified together"),
    };
    let beam = Beam::from_arcsec(major, minor, pa).context("invalid target beam")?;
    Ok(Some(beam))
}
/// Round a beam up for output, and optionally make it circular.
///
/// The major and minor axes are rounded up to 0.1 arcsec and the PA up to 0.01 deg.
/// With `circularise`, the result is then rebuilt with minor = major and PA = 0.
///
/// # Errors
/// Fails if `Beam::from_arcsec` rejects the rounded or circularised parameters.
fn apply_beam_rounding(b: Beam, circularise: bool) -> Result<Beam> {
    let major = ceil_to(b.major_arcsec(), 1);
    let minor = ceil_to(b.minor_arcsec(), 1);
    let pa = round_up(b.pa_deg, 2);
    let rounded = Beam::from_arcsec(major, minor, pa).context("rounding common beam")?;
    if !circularise {
        return Ok(rounded);
    }
    let side = rounded.major_arcsec();
    Beam::from_arcsec(side, side, 0.0).context("circularising beam")
}
/// Create a progress bar with `total` steps, showing elapsed time, a 40-column bar,
/// `pos/len` and a message. The bar ticks every 100 ms.
fn progress_bar(total: u64) -> ProgressBar {
    let pb = ProgressBar::new(total);
    pb.set_style(
        ProgressStyle::default_bar()
            .template("{spinner:.green} [{elapsed}] [{bar:40.cyan/blue}] {pos}/{len} {msg}")
            // The template is a compile-time literal, so a parse failure is a bug.
            .expect("progress-bar template literal is valid")
            .progress_chars("=>-"),
    );
    pb.enable_steady_tick(Duration::from_millis(100));
    pb
}
/// Create an indeterminate spinner that shows elapsed time and `msg`. It ticks every
/// 100 ms until the caller finishes or clears it.
fn spinner(msg: impl Into<String>) -> ProgressBar {
    let pb = ProgressBar::new_spinner();
    pb.set_style(
        ProgressStyle::default_spinner()
            .template("{spinner:.green} [{elapsed}] {msg}")
            // The template is a compile-time literal, so a parse failure is a bug.
            .expect("spinner template literal is valid"),
    );
    pb.set_message(msg.into());
    pb.enable_steady_tick(Duration::from_millis(100));
    pb
}
/// Round `x` up to `precision` decimal places.
///
/// Values within floating-point noise of a grid point snap to that point instead of
/// being bumped a whole step up. Without this, `0.1 * 3.0` (`0.30000000000000004`)
/// would "ceil" to 0.4. A beam size that has round-tripped arcsec → deg → arcsec can
/// likewise pick up a spurious extra 0.1".
///
/// The snap means the result may be below `x` by at most `1e-9 / 10^precision`.
fn ceil_to(x: f64, precision: i32) -> f64 {
    // Absolute slack on the scaled value. For the magnitudes rounded here (arcsec),
    // it is well above the rounding error of `x * factor` and far below any
    // meaningful beam precision.
    const SNAP: f64 = 1e-9;
    let factor = 10_f64.powi(precision);
    (x * factor - SNAP).ceil() / factor
}
/// Round `x` up to `decimals` decimal places. It is used for position angles in
/// degrees and follows the same rule as `ceil_to`.
///
/// Values within floating-point noise of a grid point snap to that point rather than
/// being bumped a whole step. For example, `0.1 * 3.0` rounds to 0.3, not 0.4. The
/// result may be below `x` by at most `1e-9 / 10^decimals`.
fn round_up(x: f64, decimals: i32) -> f64 {
    // Absolute slack on the scaled value; see `ceil_to`.
    const SNAP: f64 = 1e-9;
    let factor = 10_f64.powi(decimals);
    (x * factor - SNAP).ceil() / factor
}