assist-rs 0.1.0

Domain layer for ASSIST + REBOUND solar-system propagation: orbits, ephemerides, observatories, STM, batch + parallel
Documentation
//! Validation test: compare assist-rs propagation against Python ASSIST reference data.
//!
//! The reference data is generated by `validation/generate_reference.py` using
//! Python REBOUND 4.6.0 + ASSIST (the same C libraries, though not necessarily
//! the same ASSIST release). This test propagates the same 28 sample orbits from
//! adam_core helpers and asserts that the results agree within tolerance.
//!
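//! Requires:
//!   ASSIST_PLANETS_PATH  — path to de440.bsp
//!   ASSIST_ASTEROIDS_PATH — path to sb441-n16.bsp
//!   validation/reference_data.json — generated by the Python script
//!
//! A test that is missing one of the inputs it needs prints a "Skipping" note
//! and returns early, so it is reported as passing.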
//!
//! Run:
//!   cargo test --test validation_test -- --test-threads=1

use std::path::PathBuf;

use rayon::prelude::*;

/// Top-level layout of `validation/reference_data.json`.
///
/// Field names mirror the JSON keys written by the Python generator.
#[derive(serde::Deserialize)]
struct ReferenceData {
    /// Library versions and propagation settings used to produce the file.
    metadata: Metadata,
    /// Orbits validated without non-gravitational parameters (the "gravity" batch).
    orbits: Vec<OrbitEntry>,
    /// Orbits validated with the parameters in `metadata.nongrav_params`.
    /// Older reference files lack this key; it then deserializes as empty.
    #[serde(default)]
    nongrav_orbits: Vec<OrbitEntry>,
}

/// `metadata` section of the reference file: versions and settings the
/// Python generator ran with.
#[derive(serde::Deserialize)]
struct Metadata {
    /// Python REBOUND version used for the reference (printed by the tests).
    rebound_version: String,
    /// Python ASSIST version used for the reference (printed by the tests).
    assist_version: String,
    /// Length of the propagation window in days (only printed here).
    propagation_days: f64,
    /// Number of output epochs per orbit (only printed here; the actual
    /// epochs come from each orbit's `propagated` list).
    n_epochs: usize,
    // The next three fields are deserialized but not read by the tests in this
    // file, hence the per-field `dead_code` allowances.
    /// Presumably a reference Julian date — TODO confirm against the generator.
    #[allow(dead_code)]
    jd_ref: f64,
    /// Presumably cos of the obliquity used for frame rotation — TODO confirm.
    #[allow(dead_code)]
    cos_eps: f64,
    /// Presumably sin of the obliquity used for frame rotation — TODO confirm.
    #[allow(dead_code)]
    sin_eps: f64,
    /// Non-gravitational parameters [A1, A2, A3] used for `nongrav_orbits`.
    /// Optional for backward compatibility with older reference files.
    #[serde(default)]
    nongrav_params: Option<[f64; 3]>,
}

/// One reference orbit: its initial conditions plus the Python-propagated
/// states it must be compared against.
#[derive(serde::Deserialize, Clone)]
struct OrbitEntry {
    /// Identifier used in failure and summary messages.
    object_id: String,
    /// Epoch of `initial_state`, passed straight to `assist_rs::Orbit` (MJD per the name).
    epoch_mjd: f64,
    /// Initial state: components 0..3 are position (AU), 3..6 velocity (AU/day),
    /// matching the units used when computing errors in `validate_orbit`.
    initial_state: [f64; 6],
    /// Reference states at each output epoch, in propagation order.
    propagated: Vec<PropagatedEntry>,
}

/// A single reference state produced by the Python propagation.
#[derive(serde::Deserialize, Clone)]
struct PropagatedEntry {
    /// Target epoch; used verbatim as a propagation target on the Rust side
    /// (presumably MJD, like `OrbitEntry::epoch_mjd` — TODO confirm).
    epoch: f64,
    /// State at `epoch`: position (AU) in 0..3, velocity (AU/day) in 3..6.
    state: [f64; 6],
}

/// Read the planet and asteroid ephemeris paths from the environment.
///
/// Returns `None` when either `ASSIST_PLANETS_PATH` or `ASSIST_ASTEROIDS_PATH`
/// is unset (or not valid Unicode); the planets variable is checked first.
fn ephem_paths() -> Option<(PathBuf, PathBuf)> {
    let path_from_env = |key: &str| std::env::var(key).ok().map(PathBuf::from);

    let planets = path_from_env("ASSIST_PLANETS_PATH")?;
    let asteroids = path_from_env("ASSIST_ASTEROIDS_PATH")?;
    Some((planets, asteroids))
}

/// Load `validation/reference_data.json` from the crate root.
///
/// Returns `None` only when the file does not exist, so callers can skip
/// cleanly when the reference has not been generated.
///
/// # Panics
///
/// Panics if the file exists but cannot be read or does not match the
/// expected schema. Silently returning `None` there would make the validation
/// tests report "not found" and pass without comparing anything.
fn load_reference_data() -> Option<ReferenceData> {
    let path = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
        .join("validation")
        .join("reference_data.json");

    // Decide "missing" from the read itself rather than a separate `exists()`
    // probe, so there is no window between the check and the read.
    let data = match std::fs::read_to_string(&path) {
        Ok(data) => data,
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => return None,
        Err(e) => panic!("Failed to read {}: {e}", path.display()),
    };

    let reference = serde_json::from_str(&data)
        .unwrap_or_else(|e| panic!("Failed to parse {}: {e}", path.display()));
    Some(reference)
}

/// Propagate a single orbit with ASSIST and compare it against the Python reference.
///
/// When `non_grav` is `Some([a1, a2, a3])` the orbit is built with those
/// non-gravitational parameters; `None` builds a plain orbit. The orbit is
/// propagated to every epoch listed in `entry.propagated`.
///
/// Returns `(object_id, max_position_error_au, max_velocity_error_au_per_day)`:
/// the largest Euclidean position / velocity differences over all epochs.
/// A NaN difference (non-finite state on either side) is carried through as
/// NaN instead of being dropped, so the caller's `< tolerance` check fails.
///
/// # Panics
///
/// Panics if propagation returns an error, or returns a different number of
/// epochs than the reference lists.
fn validate_orbit(
    data: &assist_rs::AssistData,
    entry: &OrbitEntry,
    non_grav: Option<[f64; 3]>,
) -> (String, f64, f64) {
    /// `f64::max` that returns NaN if either input is NaN. Plain `f64::max`
    /// returns the non-NaN operand, which would turn a diverged propagation
    /// into a reported error of zero.
    fn nan_propagating_max(a: f64, b: f64) -> f64 {
        if a.is_nan() || b.is_nan() {
            f64::NAN
        } else {
            a.max(b)
        }
    }

    let orbit = match non_grav {
        Some([a1, a2, a3]) => assist_rs::Orbit::with_non_grav(
            entry.initial_state,
            entry.epoch_mjd,
            assist_rs::NonGravParams::new(a1, a2, a3),
        ),
        None => assist_rs::Orbit::new(entry.initial_state, entry.epoch_mjd),
    };
    let target_epochs: Vec<f64> = entry.propagated.iter().map(|p| p.epoch).collect();

    // NOTE(review): the `false` flag is forwarded unchanged; presumably it
    // disables STM / variational output — confirm against assist_propagate_single.
    let results = assist_rs::assist_propagate_single(
        data,
        &orbit,
        &target_epochs,
        false,
        &assist_rs::IntegratorConfig::default(),
    )
    .unwrap_or_else(|e| panic!("Propagation failed for {}: {e}", entry.object_id));

    assert_eq!(
        results.len(),
        entry.propagated.len(),
        "Epoch count mismatch for {}",
        entry.object_id
    );

    let mut max_pos_err = 0.0f64;
    let mut max_vel_err = 0.0f64;

    for (rust_result, py_ref) in results.iter().zip(entry.propagated.iter()) {
        // Position error (AU)
        let pos_err = ((rust_result.state[0] - py_ref.state[0]).powi(2)
            + (rust_result.state[1] - py_ref.state[1]).powi(2)
            + (rust_result.state[2] - py_ref.state[2]).powi(2))
        .sqrt();

        // Velocity error (AU/day)
        let vel_err = ((rust_result.state[3] - py_ref.state[3]).powi(2)
            + (rust_result.state[4] - py_ref.state[4]).powi(2)
            + (rust_result.state[5] - py_ref.state[5]).powi(2))
        .sqrt();

        max_pos_err = nan_propagating_max(max_pos_err, pos_err);
        max_vel_err = nan_propagating_max(max_vel_err, vel_err);
    }

    (entry.object_id.clone(), max_pos_err, max_vel_err)
}

/// Validate a batch of orbits in parallel and assert every one stays within
/// tolerance of the reference.
///
/// Each entry is checked by [`validate_orbit`] on the rayon pool, all sharing
/// `data`. Out-of-tolerance entries are printed as `FAIL [...]` lines, then a
/// summary with the worst position and velocity errors is printed.
///
/// # Panics
///
/// - If `entries` is empty: a vacuous pass would hide a truncated reference.
/// - If any orbit fails to propagate (inside `validate_orbit`).
/// - If any orbit's error is not strictly below tolerance, which includes NaN
///   errors from diverged propagations.
fn run_validation_batch(
    label: &str,
    data: &assist_rs::AssistData,
    entries: &[OrbitEntry],
    non_grav: Option<[f64; 3]>,
) {
    /// Record `(err, name)` as the new worst if it is worse than the current
    /// one. NaN outranks any finite error; ties keep the first name seen, so
    /// the name is populated even when every error is exactly zero.
    fn update_worst<'a>(worst: &mut Option<(f64, &'a str)>, err: f64, name: &'a str) {
        let replace = match *worst {
            None => true,
            Some((current, _)) => !current.is_nan() && (err.is_nan() || err > current),
        };
        if replace {
            *worst = Some((err, name));
        }
    }

    // Tolerance: since we use ASSIST 1.2.0 and the reference uses 1.1.9a2,
    // there may be small differences from C library changes.
    // Position tolerance: 1e-10 AU ≈ 15 meters
    // Velocity tolerance: 1e-12 AU/day ≈ 0.002 mm/s
    //
    // If both sides use the exact same ASSIST version, these should be
    // bit-identical (< 1e-15). With version differences, we allow a
    // generous tolerance.
    let pos_tol_au = 1e-10;
    let vel_tol_au_day = 1e-12;

    assert!(
        !entries.is_empty(),
        "[{label}] no orbits to validate; refusing to pass vacuously"
    );

    let results: Vec<(String, f64, f64)> = entries
        .par_iter()
        .map(|entry| validate_orbit(data, entry, non_grav))
        .collect();

    let mut n_pass = 0;
    let mut n_fail = 0;
    let mut worst_pos: Option<(f64, &str)> = None;
    let mut worst_vel: Option<(f64, &str)> = None;

    for (name, pos_err, vel_err) in &results {
        update_worst(&mut worst_pos, *pos_err, name);
        update_worst(&mut worst_vel, *vel_err, name);

        // `<` is false for NaN, so a non-finite error counts as a failure.
        let pos_ok = *pos_err < pos_tol_au;
        let vel_ok = *vel_err < vel_tol_au_day;

        if pos_ok && vel_ok {
            n_pass += 1;
        } else {
            n_fail += 1;
            eprintln!(
                "FAIL [{label}]: {name}: pos_err={pos_err:.2e} AU, vel_err={vel_err:.2e} AU/day"
            );
        }
    }

    let (worst_pos_err, worst_pos_name) = worst_pos.unwrap_or((0.0, ""));
    let (worst_vel_err, worst_vel_name) = worst_vel.unwrap_or((0.0, ""));

    eprintln!("\n--- Validation Summary [{label}] ---");
    eprintln!("Orbits tested: {}", results.len());
    eprintln!("Passed: {n_pass}, Failed: {n_fail}");
    eprintln!("Worst position error: {worst_pos_err:.2e} AU ({worst_pos_name})");
    eprintln!("Worst velocity error: {worst_vel_err:.2e} AU/day ({worst_vel_name})");
    eprintln!("Tolerances: pos < {pos_tol_au:.0e} AU, vel < {vel_tol_au_day:.0e} AU/day");

    assert_eq!(
        n_fail, 0,
        "[{label}] {n_fail} orbits exceeded tolerance (worst pos: {worst_pos_err:.2e} AU, worst vel: {worst_vel_err:.2e} AU/day)"
    );
}

/// Gravity-only validation: every orbit in `reference.orbits` must match the
/// Python propagation within tolerance. Skips if an input is missing.
#[test]
fn test_validation_against_python() {
    let Some((planets, asteroids)) = ephem_paths() else {
        eprintln!("Skipping: ASSIST_PLANETS_PATH / ASSIST_ASTEROIDS_PATH not set");
        return;
    };
    let Some(reference) = load_reference_data() else {
        eprintln!("Skipping: validation/reference_data.json not found");
        eprintln!("Generate it with: python validation/generate_reference.py");
        return;
    };

    let meta = &reference.metadata;
    let orbit_count = reference.orbits.len();
    eprintln!(
        "Reference data: {} orbits, Python REBOUND {} + ASSIST {}",
        orbit_count, meta.rebound_version, meta.assist_version,
    );
    eprintln!(
        "Propagation: {} days, {} epochs per orbit",
        meta.propagation_days, meta.n_epochs
    );

    let data = assist_rs::AssistData::new(
        assist_rs::Ephemeris::from_paths(&planets, &asteroids).expect("Failed to load ephemeris"),
    );

    // `run_validation_batch` spreads the orbits across rayon worker threads,
    // all reading the same `data`.
    run_validation_batch("gravity", &data, &reference.orbits, None);
}

/// Non-gravitational validation: every orbit in `reference.nongrav_orbits`
/// must match the Python propagation made with `metadata.nongrav_params`.
/// Skips if an input is missing or the reference predates the non-grav section.
#[test]
fn test_nongrav_validation_against_python() {
    let Some((planets, asteroids)) = ephem_paths() else {
        eprintln!("Skipping: ASSIST_PLANETS_PATH / ASSIST_ASTEROIDS_PATH not set");
        return;
    };
    let Some(reference) = load_reference_data() else {
        eprintln!("Skipping: validation/reference_data.json not found");
        return;
    };

    // Older reference files have no non-grav section: skip and suggest a regen.
    if reference.nongrav_orbits.is_empty() {
        eprintln!("Skipping: reference_data.json has no nongrav_orbits section");
        eprintln!("Regenerate it with: python validation/generate_reference.py");
        return;
    }

    // Non-grav orbits without the parameters that produced them means the
    // reference file is inconsistent, so fail rather than skip.
    let nongrav = reference
        .metadata
        .nongrav_params
        .expect("nongrav_orbits present but metadata.nongrav_params missing");
    let [a1, a2, a3] = nongrav;

    eprintln!(
        "Non-grav reference data: {} orbits with (A1, A2, A3) = ({:.2e}, {:.2e}, {:.2e})",
        reference.nongrav_orbits.len(),
        a1,
        a2,
        a3,
    );

    let data = assist_rs::AssistData::new(
        assist_rs::Ephemeris::from_paths(&planets, &asteroids).expect("Failed to load ephemeris"),
    );

    run_validation_batch(
        "non-gravitational",
        &data,
        &reference.nongrav_orbits,
        Some(nongrav),
    );
}

/// Propagating the same orbit concurrently on rayon threads, all sharing one
/// `AssistData`, must give the same answer on every run. Skips if the
/// ephemeris paths are not configured.
#[test]
fn test_rayon_parallel_propagation() {
    let Some((planets, asteroids)) = ephem_paths() else {
        eprintln!("Skipping: ephemeris not available");
        return;
    };

    let ephem =
        assist_rs::Ephemeris::from_paths(&planets, &asteroids).expect("Failed to load ephemeris");
    let data = assist_rs::AssistData::new(ephem);

    // Propagate 10 copies of one test orbit in parallel.
    let orbit = assist_rs::Orbit::new(
        [
            -1.938_169_72,
            2.289_213_79,
            1.094_048_30,
            -0.008_744_54,
            -0.005_523_16,
            0.001_174_22,
        ],
        60000.0,
    );
    let targets = vec![orbit.epoch + 10.0, orbit.epoch + 20.0, orbit.epoch + 30.0];

    let n_orbits = 10;
    let results: Vec<_> = (0..n_orbits)
        .into_par_iter()
        .map(|_| {
            assist_rs::assist_propagate_single(
                &data,
                &orbit,
                &targets,
                false,
                &assist_rs::IntegratorConfig::default(),
            )
            .expect("Propagation failed")
        })
        .collect();

    assert_eq!(results.len(), n_orbits);

    // Every run must return one state per target epoch. Without this check,
    // truncated (or uniformly empty) output would slip through the zipped
    // comparison below.
    for (i, orbit_results) in results.iter().enumerate() {
        assert_eq!(
            orbit_results.len(),
            targets.len(),
            "Run {i} returned {} epochs, expected {}",
            orbit_results.len(),
            targets.len()
        );
    }

    // All results should be identical (same input, same library); run 0 is
    // the baseline the others are compared against.
    let baseline = &results[0];
    for (i, orbit_results) in results.iter().enumerate().skip(1) {
        for (j, (expected, actual)) in baseline.iter().zip(orbit_results.iter()).enumerate() {
            for k in 0..6 {
                assert!(
                    (expected.state[k] - actual.state[k]).abs() < 1e-15,
                    "Parallel result mismatch: orbit {i}, epoch {j}, element {k}"
                );
            }
        }
    }

    eprintln!("Parallel propagation: {n_orbits} identical orbits produced identical results");
}