use apex_solver::JacobianMode;
use apex_solver::apex_camera_models::{BALPinholeCameraStrict, DistortionModel, PinholeParams};
use apex_solver::apex_io::{BalDataset, BalLoader};
use apex_solver::apex_manifolds::ManifoldType;
use apex_solver::apex_manifolds::se3::SE3;
use apex_solver::apex_manifolds::so3::SO3;
use apex_solver::core::loss_functions::HuberLoss;
use apex_solver::core::problem::Problem;
use apex_solver::factors::ProjectionFactor;
use apex_solver::init_logger;
use apex_solver::linalg::SchurVariant;
use apex_solver::optimizer::levenberg_marquardt::{LevenbergMarquardt, LevenbergMarquardtConfig};
use clap::{Parser, ValueEnum};
use nalgebra::{DVector, Matrix2xX, Vector2, Vector3};
use std::collections::HashMap;
use std::error::Error;
use std::path::PathBuf;
use std::time::Instant;
use tracing::info;
// Command-line choice of Schur-complement solver (`--solver` / `-s`).
// It is converted to `SchurVariant` by the `From<SolverArg>` impl:
// `Explicit` -> `SchurVariant::Sparse`, `Implicit` -> `SchurVariant::Iterative`.
// Plain `//` comments are used on purpose: clap turns `///` doc comments on
// `ValueEnum` variants into user-visible help text.
#[derive(Debug, Clone, Copy, ValueEnum, Default)]
enum SolverArg {
    // CLI value `explicit`.
    Explicit,
    // CLI value `implicit`; matches the `default_value = "implicit"` on `Args::solver`.
    #[default]
    Implicit,
}
impl From<SolverArg> for SchurVariant {
    /// Maps the CLI solver choice onto the library's Schur-complement variant:
    /// `explicit` selects [`SchurVariant::Sparse`] and `implicit` selects
    /// [`SchurVariant::Iterative`].
    fn from(choice: SolverArg) -> Self {
        match choice {
            SolverArg::Implicit => Self::Iterative,
            SolverArg::Explicit => Self::Sparse,
        }
    }
}
// Command-line choice of optimization mode (`--optimization-type` / `-t`).
// `run_bundle_adjustment` maps each variant to the `apex_solver::factors`
// configuration type of the same name. `SelfCalibration` and `OnlyIntrinsics`
// additionally connect every projection factor to the camera's intrinsics
// variable; the other modes connect only the pose and the landmark.
// NOTE(review): exactly which variables each mode optimizes is defined by those
// library configuration types — confirm there.
// Plain `//` comments keep clap's generated help text unchanged.
#[derive(Debug, Clone, Copy, ValueEnum, Default)]
enum OptimizationType {
    BundleAdjustment,
    // Default mode; matches `default_value = "self-calibration"` on `Args`.
    #[default]
    SelfCalibration,
    OnlyPose,
    OnlyLandmarks,
    OnlyIntrinsics,
}
// Command-line arguments. The `///` comments on the fields below are rendered by
// clap as the `--help` description of each argument.
#[derive(Parser)]
#[command(name = "bundle_adjustment")]
#[command(about = "Bundle adjustment optimization for BAL datasets")]
struct Args {
    /// Path to the BAL problem file to optimize
    #[arg(value_name = "FILE")]
    file: PathBuf,
    /// Camera count in the problem name; with --points, downloads FILE when it does not exist
    #[arg(long)]
    cameras: Option<u32>,
    /// Point count in the problem name; with --cameras, downloads FILE when it does not exist
    #[arg(long)]
    points: Option<u32>,
    /// Use only this many points, starting from the first (default: all points)
    #[arg(short = 'n', long)]
    num_points: Option<usize>,
    /// Schur complement variant: explicit (sparse) or implicit (iterative)
    #[arg(short = 's', long, value_enum, default_value = "implicit")]
    solver: SolverArg,
    /// Optimization mode; selects the projection-factor configuration
    #[arg(short = 't', long, value_enum, default_value = "self-calibration")]
    optimization_type: OptimizationType,
    /// Also log the average time per iteration
    #[arg(short, long)]
    verbose: bool,
    /// Stream the optimization to a Rerun viewer
    #[arg(long)]
    #[cfg(feature = "visualization")]
    with_visualizer: bool,
}
fn main() -> Result<(), Box<dyn Error>> {
let args = Args::parse();
init_logger();
info!("APEX-SOLVER BUNDLE ADJUSTMENT");
info!("");
let file_path = if args.file.exists() {
args.file.clone()
} else if let (Some(cameras), Some(points)) = (args.cameras, args.points) {
let dataset_name = args
.file
.parent()
.and_then(|p| p.file_name())
.and_then(|n| n.to_str())
.ok_or_else(|| {
format!(
"Cannot infer dataset name from path '{}'. \
Ensure the path has the form data/bundle_adjustment/<name>/problem-…txt",
args.file.display()
)
})?;
info!(
"File not found — downloading {dataset_name}/problem-{cameras}-{points} from registry …"
);
apex_solver::apex_io::ensure_ba_dataset(dataset_name, cameras, points)
.map_err(|e| format!("Auto-download failed: {e}"))?
} else {
return Err(format!(
"File not found: {}. Provide --cameras and --points to download it automatically.",
args.file.display()
)
.into());
};
info!("Loading BAL dataset: {}", file_path.display());
let start_load = Instant::now();
let dataset = BalLoader::load(file_path.to_string_lossy().as_ref())?;
let load_time = start_load.elapsed();
let num_points_to_use = args.num_points.unwrap_or(dataset.points.len());
let num_points_to_use = num_points_to_use.min(dataset.points.len());
info!("Dataset statistics:");
info!(" Cameras: {}", dataset.cameras.len());
info!(" Total points: {}", dataset.points.len());
info!(" Points to use: {}", num_points_to_use);
info!(" Observations: {}", dataset.observations.len());
info!(" Load time: {:?}", load_time);
info!("");
#[cfg(feature = "visualization")]
let with_visualizer = args.with_visualizer;
#[cfg(not(feature = "visualization"))]
let with_visualizer = false;
run_bundle_adjustment(
&dataset,
num_points_to_use,
args.solver.into(),
args.optimization_type,
args.verbose,
with_visualizer,
)
}
/// Converts a BAL axis-angle vector into an [`SO3`] rotation.
///
/// The vector's direction is the rotation axis and its norm is the rotation
/// angle. Norms below `1e-10` map to the identity, which avoids normalizing a
/// (near-)zero vector.
fn axis_angle_to_so3(axis_angle: &Vector3<f64>) -> SO3 {
    let theta = axis_angle.norm();
    if theta < 1e-10 {
        return SO3::identity();
    }
    let unit_axis = axis_angle / theta;
    SO3::from_axis_angle(&unit_axis, theta)
}
#[cfg_attr(not(feature = "visualization"), allow(unused_variables))]
fn run_bundle_adjustment(
dataset: &BalDataset,
num_points: usize,
solver_variant: SchurVariant,
opt_type: OptimizationType,
verbose: bool,
with_visualizer: bool,
) -> Result<(), Box<dyn Error>> {
use apex_solver::factors::{
BundleAdjustment, OnlyIntrinsics, OnlyLandmarks, OnlyPose, SelfCalibration,
};
let mut problem = Problem::new(JacobianMode::Sparse);
let mut initial_values = HashMap::new();
info!(
"Adding {} cameras as SE3 poses + intrinsics...",
dataset.cameras.len()
);
for (i, cam) in dataset.cameras.iter().enumerate() {
let axis_angle = Vector3::new(cam.rotation.x, cam.rotation.y, cam.rotation.z);
let translation = Vector3::new(cam.translation.x, cam.translation.y, cam.translation.z);
let so3 = axis_angle_to_so3(&axis_angle);
let pose = SE3::from_translation_so3(translation, so3);
let pose_name = format!("pose_{:04}", i);
initial_values.insert(pose_name, (ManifoldType::SE3, DVector::from(pose)));
let intrinsics_name = format!("intr_{:04}", i);
let intrinsics_vec = DVector::from_vec(vec![cam.focal_length, cam.k1, cam.k2]);
initial_values.insert(intrinsics_name, (ManifoldType::RN, intrinsics_vec));
}
info!("Adding {} landmarks as RN(3) variables...", num_points);
for j in 0..num_points {
let point = &dataset.points[j];
let var_name = format!("pt_{:05}", j);
let point_vec =
DVector::from_vec(vec![point.position.x, point.position.y, point.position.z]);
initial_values.insert(var_name, (ManifoldType::RN, point_vec));
}
let valid_obs: Vec<_> = dataset
.observations
.iter()
.filter(|obs| obs.point_index < num_points)
.collect();
info!(
"Adding {} projection factors (optimization: {:?})...",
valid_obs.len(),
opt_type
);
match opt_type {
OptimizationType::SelfCalibration => {
add_factors::<SelfCalibration>(&mut problem, dataset, &valid_obs, true)?;
}
OptimizationType::BundleAdjustment => {
add_factors::<BundleAdjustment>(&mut problem, dataset, &valid_obs, false)?;
}
OptimizationType::OnlyPose => {
add_factors::<OnlyPose>(&mut problem, dataset, &valid_obs, false)?;
}
OptimizationType::OnlyLandmarks => {
add_factors::<OnlyLandmarks>(&mut problem, dataset, &valid_obs, false)?;
}
OptimizationType::OnlyIntrinsics => {
add_factors::<OnlyIntrinsics>(&mut problem, dataset, &valid_obs, true)?;
}
}
info!("Fixing first camera pose (all 6 DOF) for gauge freedom...");
for dof in 0..6 {
problem.fix_variable("pose_0000", dof);
}
let mut config = LevenbergMarquardtConfig::for_bundle_adjustment();
config.schur_variant = solver_variant;
info!("");
info!("Solver configuration:");
info!(" Solver variant: {:?}", solver_variant);
info!(" Optimization type: {:?}", opt_type);
info!(" Linear solver: {:?}", config.linear_solver_type);
info!(" Preconditioner: {:?}", config.schur_preconditioner);
let mut solver = LevenbergMarquardt::with_config(config);
#[cfg(feature = "visualization")]
if with_visualizer {
use apex_solver::observers::RerunObserver;
use apex_solver::observers::visualization::VisualizationConfig;
let config = VisualizationConfig::for_bundle_adjustment()
.with_camera_frustum_scale(0.1)
.with_initial_landmark_color([255, 255, 255])
.with_optimized_landmark_color([255, 255, 255]);
match RerunObserver::with_config(true, None, config) {
Ok(observer) => {
solver.add_observer(observer);
info!("Rerun visualization enabled for bundle adjustment");
}
Err(e) => tracing::warn!("Failed to create Rerun observer: {}", e),
}
}
let num_cameras = dataset.cameras.len();
let num_factors = valid_obs.len();
let pose_dof = num_cameras * 6;
let intrinsic_dof = num_cameras * 3;
let landmark_dof = num_points * 3;
let total_dof = pose_dof + intrinsic_dof + landmark_dof;
info!("");
info!("Diagnostics:");
info!(" Cameras: {}", num_cameras);
info!(" Number of factors (observations): {}", num_factors);
info!(" Pose DOF: {} (6 per camera)", pose_dof);
info!(" Intrinsic DOF: {} (3 per camera)", intrinsic_dof);
info!(" Landmark DOF: {}", landmark_dof);
info!(" Total DOF: {}", total_dof);
info!(
" DOF per observation: {:.2}",
total_dof as f64 / num_factors as f64
);
info!("");
info!("Starting optimization...");
let start = Instant::now();
let result = solver.optimize(&problem, &initial_values)?;
let elapsed = start.elapsed();
info!("");
info!("Optimization completed!");
info!("Status: {:?}", result.status);
info!("Iterations: {}", result.iterations);
info!("Time: {:.2} seconds", elapsed.as_secs_f64());
let num_obs = valid_obs.len() as f64;
let initial_rmse = (result.initial_cost / num_obs).sqrt();
let final_rmse = (result.final_cost / num_obs).sqrt();
info!("");
info!("Metrics:");
info!(" Initial cost: {:.6e}", result.initial_cost);
info!(" Final cost: {:.6e}", result.final_cost);
info!(" Initial RMSE: {:.3} pixels", initial_rmse);
info!(" Final RMSE: {:.3} pixels", final_rmse);
info!(
" Improvement: {:.2}%",
(result.initial_cost - result.final_cost) / result.initial_cost * 100.0
);
if verbose {
info!("");
info!(
" Per-iteration: {:.2}s",
elapsed.as_secs_f64() / result.iterations as f64
);
}
Ok(())
}
/// Adds one projection residual block to `problem` per observation in `valid_obs`.
///
/// Each factor owns a `BALPinholeCameraStrict` built from the observing camera:
/// `fx = fy = focal_length`, principal point at the origin, and radial
/// distortion `k1`, `k2`. Residuals are wrapped in a Huber loss with delta 1.0.
/// Every block connects `pose_NNNN` and `pt_NNNNN`. When `include_intrinsics`
/// is true it also connects `intr_NNNN`. The optimization configuration `OP`
/// is forwarded as the `ProjectionFactor` type parameter.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// - an observation's camera index is outside `dataset.cameras`;
/// - the camera model rejects a camera's parameters;
/// - the Huber loss cannot be created.
fn add_factors<OP>(
    problem: &mut Problem,
    dataset: &BalDataset,
    valid_obs: &[&apex_solver::apex_io::BalObservation],
    include_intrinsics: bool,
) -> Result<(), Box<dyn Error>>
where
    OP: apex_solver::factors::projection_factor::OptimizationConfig + 'static,
{
    for obs in valid_obs {
        // A malformed dataset could reference a missing camera; report it
        // instead of panicking on an out-of-bounds index.
        let cam = dataset.cameras.get(obs.camera_index).ok_or_else(|| {
            format!(
                "Observation references camera {} but the dataset has only {} cameras",
                obs.camera_index,
                dataset.cameras.len()
            )
        })?;
        let camera = BALPinholeCameraStrict::new(
            PinholeParams {
                fx: cam.focal_length,
                fy: cam.focal_length,
                cx: 0.0,
                cy: 0.0,
            },
            DistortionModel::Radial {
                k1: cam.k1,
                k2: cam.k2,
            },
        )?;
        let observations = Matrix2xX::from_columns(&[Vector2::new(obs.x, obs.y)]);
        let factor: ProjectionFactor<BALPinholeCameraStrict, OP> =
            ProjectionFactor::new(observations, camera);
        // The delta is identical for every factor, so a failure here would
        // affect all of them; surface it instead of silently dropping factors.
        let loss = Box::new(
            HuberLoss::new(1.0).map_err(|_| "Failed to create Huber loss (delta = 1.0)")?,
        );
        let pose_name = format!("pose_{:04}", obs.camera_index);
        let pt_name = format!("pt_{:05}", obs.point_index);
        if include_intrinsics {
            let intr_name = format!("intr_{:04}", obs.camera_index);
            problem.add_residual_block(
                &[&pose_name, &pt_name, &intr_name],
                Box::new(factor),
                Some(loss),
            );
        } else {
            problem.add_residual_block(&[&pose_name, &pt_name], Box::new(factor), Some(loss));
        }
    }
    Ok(())
}