use std::path::PathBuf;
use rayon::prelude::*;
/// Top-level schema of `validation/reference_data.json`, produced by the Python
/// REBOUND + ASSIST reference run and read by `load_reference_data`.
#[derive(serde::Deserialize)]
struct ReferenceData {
/// Versions and run configuration of the reference propagation.
metadata: Metadata,
/// Reference orbits propagated without non-gravitational forces
/// (checked by `test_validation_against_python`).
orbits: Vec<OrbitEntry>,
/// Reference orbits propagated with `metadata.nongrav_params`. `#[serde(default)]`
/// lets an older file without this key load as an empty list, in which case
/// `test_nongrav_validation_against_python` skips.
#[serde(default)]
nongrav_orbits: Vec<OrbitEntry>,
}
/// Provenance and configuration recorded alongside the reference orbits.
#[derive(serde::Deserialize)]
struct Metadata {
/// Python REBOUND version used to generate the reference (printed by the tests).
rebound_version: String,
/// Python ASSIST version used to generate the reference (printed by the tests).
assist_version: String,
/// Length of the reference propagation in days; only printed by the visible tests.
propagation_days: f64,
/// Number of output epochs per orbit; only printed, not checked against the data.
n_epochs: usize,
// The next three fields are never read by these tests. They are still declared
// (and are not optional), so deserialization fails if the file lacks them;
// `dead_code` is allowed to silence the unused-field warning.
// NOTE(review): the names suggest a reference Julian date and the cosine/sine of the
// obliquity used by the generator — confirm against validation/generate_reference.py.
#[allow(dead_code)]
jd_ref: f64,
#[allow(dead_code)]
cos_eps: f64,
#[allow(dead_code)]
sin_eps: f64,
/// Non-gravitational coefficients `(A1, A2, A3)` applied to every entry in
/// `nongrav_orbits`; `None` when the key is absent.
#[serde(default)]
nongrav_params: Option<[f64; 3]>,
}
/// One reference orbit: an initial state plus the Python-propagated states to compare against.
#[derive(serde::Deserialize, Clone)]
struct OrbitEntry {
/// Object identifier used in failure and summary messages.
object_id: String,
/// Epoch of `initial_state`, passed to `assist_rs::Orbit` (the field name indicates MJD).
epoch_mjd: f64,
/// Six-component initial state passed to `assist_rs::Orbit`; presumably
/// `[x, y, z, vx, vy, vz]`, the same layout as `PropagatedEntry::state` — confirm.
initial_state: [f64; 6],
/// Reference states; their epochs are used as the propagation targets.
propagated: Vec<PropagatedEntry>,
}
/// A single Python reference state at one output epoch.
#[derive(serde::Deserialize, Clone)]
struct PropagatedEntry {
/// Target epoch requested from `assist_rs::assist_propagate_single`.
epoch: f64,
/// Reference state: `[0..3]` is compared as position (reported in AU) and `[3..6]` as
/// velocity (reported in AU/day).
state: [f64; 6],
}
/// Reads the planet and asteroid ephemeris file locations from the environment.
///
/// Returns `Some((planets, asteroids))` only when both `ASSIST_PLANETS_PATH` and
/// `ASSIST_ASTEROIDS_PATH` are set to valid Unicode; otherwise `None`, which the tests
/// treat as "skip".
fn ephem_paths() -> Option<(PathBuf, PathBuf)> {
    let path_from_env = |key: &str| std::env::var(key).ok().map(PathBuf::from);
    path_from_env("ASSIST_PLANETS_PATH").zip(path_from_env("ASSIST_ASTEROIDS_PATH"))
}
/// Loads `validation/reference_data.json` from the crate root.
///
/// Returns `None` only when the file does not exist, so callers can skip validation on
/// machines where the reference data was never generated.
///
/// # Panics
///
/// Panics if the file exists but cannot be read or does not deserialize into
/// `ReferenceData`. Such failures must not be reported as "not found": that would
/// silently skip validation on a corrupt or schema-mismatched file.
fn load_reference_data() -> Option<ReferenceData> {
    let path = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
        .join("validation")
        .join("reference_data.json");
    if !path.exists() {
        return None;
    }
    let data = std::fs::read_to_string(&path)
        .unwrap_or_else(|e| panic!("Failed to read {}: {e}", path.display()));
    let reference: ReferenceData = serde_json::from_str(&data)
        .unwrap_or_else(|e| panic!("Failed to parse {}: {e}", path.display()));
    Some(reference)
}
/// Propagates one reference orbit with `assist_rs` and measures how far the result
/// deviates from the Python reference states.
///
/// When `non_grav` is `Some([a1, a2, a3])`, the orbit is built with
/// `assist_rs::NonGravParams::new(a1, a2, a3)`; otherwise it is purely gravitational.
/// The orbit is propagated to every `epoch` listed in `entry.propagated`.
///
/// Returns `(object_id, max_pos_err, max_vel_err)`. Each error is the largest Euclidean
/// distance, over all reference epochs, between the Rust and Python `state[0..3]`
/// (position) or `state[3..6]` (velocity) components. A NaN error at any epoch makes the
/// corresponding maximum NaN, so it can never pass a `< tolerance` check.
///
/// # Panics
///
/// Panics if propagation fails, or if it returns a different number of states than
/// `entry.propagated` has epochs.
fn validate_orbit(
    data: &assist_rs::AssistData,
    entry: &OrbitEntry,
    non_grav: Option<[f64; 3]>,
) -> (String, f64, f64) {
    let orbit = match non_grav {
        Some([a1, a2, a3]) => assist_rs::Orbit::with_non_grav(
            entry.initial_state,
            entry.epoch_mjd,
            assist_rs::NonGravParams::new(a1, a2, a3),
        ),
        None => assist_rs::Orbit::new(entry.initial_state, entry.epoch_mjd),
    };
    let target_epochs: Vec<f64> = entry.propagated.iter().map(|p| p.epoch).collect();
    let results = assist_rs::assist_propagate_single(
        data,
        &orbit,
        &target_epochs,
        false,
        &assist_rs::IntegratorConfig::default(),
    )
    .unwrap_or_else(|e| panic!("Propagation failed for {}: {e}", entry.object_id));
    assert_eq!(
        results.len(),
        entry.propagated.len(),
        "Epoch count mismatch for {}",
        entry.object_id
    );

    let mut max_pos_err = 0.0f64;
    let mut max_vel_err = 0.0f64;
    for (rust_result, py_ref) in results.iter().zip(&entry.propagated) {
        let diff = |k: usize| rust_result.state[k] - py_ref.state[k];
        let pos_err = (diff(0).powi(2) + diff(1).powi(2) + diff(2).powi(2)).sqrt();
        let vel_err = (diff(3).powi(2) + diff(4).powi(2) + diff(5).powi(2)).sqrt();
        max_pos_err = max_propagating_nan(max_pos_err, pos_err);
        max_vel_err = max_propagating_nan(max_vel_err, vel_err);
    }
    (entry.object_id.clone(), max_pos_err, max_vel_err)
}

/// Maximum of two values that returns NaN if either is NaN.
///
/// `f64::max` returns the non-NaN operand instead, which would let a propagation that
/// produced NaN report an error of zero and pass validation.
fn max_propagating_nan(a: f64, b: f64) -> f64 {
    if a.is_nan() || b.is_nan() {
        f64::NAN
    } else {
        a.max(b)
    }
}
/// Validates every orbit in `entries` in parallel (rayon) and asserts that all of them are
/// within tolerance.
///
/// `label` tags the diagnostic output. `non_grav` is forwarded unchanged to
/// `validate_orbit` for each entry. Each failing orbit and a final summary are written to
/// stderr before the assertion.
///
/// Tolerances: max position error < 1e-10 AU and max velocity error < 1e-12 AU/day.
///
/// # Panics
///
/// Panics if any orbit exceeds either tolerance (a NaN error counts as exceeding it), or
/// if `validate_orbit` panics for an entry.
fn run_validation_batch(
    label: &str,
    data: &assist_rs::AssistData,
    entries: &[OrbitEntry],
    non_grav: Option<[f64; 3]>,
) {
    let pos_tol_au = 1e-10;
    let vel_tol_au_day = 1e-12;

    let results: Vec<(String, f64, f64)> = entries
        .par_iter()
        .map(|entry| validate_orbit(data, entry, non_grav))
        .collect();

    // A plain `err > worst` never selects NaN, so a NaN failure would be missing from the
    // summary and the assertion message. Treat NaN as the worst value and keep it once seen.
    let is_worse = |err: f64, worst: f64| !worst.is_nan() && (err.is_nan() || err > worst);

    let mut n_pass = 0;
    let mut n_fail = 0;
    let mut worst_pos_err = 0.0f64;
    let mut worst_vel_err = 0.0f64;
    let mut worst_pos_name = String::new();
    let mut worst_vel_name = String::new();
    for (name, pos_err, vel_err) in &results {
        let (pos_err, vel_err) = (*pos_err, *vel_err);
        if is_worse(pos_err, worst_pos_err) {
            worst_pos_name.clone_from(name);
            worst_pos_err = pos_err;
        }
        if is_worse(vel_err, worst_vel_err) {
            worst_vel_name.clone_from(name);
            worst_vel_err = vel_err;
        }
        // `<` is false for NaN, so a NaN error is counted as a failure.
        if pos_err < pos_tol_au && vel_err < vel_tol_au_day {
            n_pass += 1;
        } else {
            n_fail += 1;
            eprintln!(
                "FAIL [{label}]: {name}: pos_err={pos_err:.2e} AU, vel_err={vel_err:.2e} AU/day"
            );
        }
    }

    eprintln!("\n--- Validation Summary [{label}] ---");
    eprintln!("Orbits tested: {}", results.len());
    eprintln!("Passed: {n_pass}, Failed: {n_fail}");
    eprintln!("Worst position error: {worst_pos_err:.2e} AU ({worst_pos_name})");
    eprintln!("Worst velocity error: {worst_vel_err:.2e} AU/day ({worst_vel_name})");
    eprintln!("Tolerances: pos < {pos_tol_au:.0e} AU, vel < {vel_tol_au_day:.0e} AU/day");
    assert_eq!(
        n_fail, 0,
        "[{label}] {n_fail} orbits exceeded tolerance (worst pos: {worst_pos_err:.2e} AU, worst vel: {worst_vel_err:.2e} AU/day)"
    );
}
/// Checks purely gravitational propagation against the Python REBOUND + ASSIST reference.
///
/// Returns early (passing) when the ephemeris environment variables are unset or the
/// reference JSON is missing.
#[test]
fn test_validation_against_python() {
    let (planets, asteroids) = match ephem_paths() {
        Some(paths) => paths,
        None => {
            eprintln!("Skipping: ASSIST_PLANETS_PATH / ASSIST_ASTEROIDS_PATH not set");
            return;
        }
    };
    let reference = match load_reference_data() {
        Some(reference) => reference,
        None => {
            eprintln!("Skipping: validation/reference_data.json not found");
            eprintln!("Generate it with: python validation/generate_reference.py");
            return;
        }
    };

    let meta = &reference.metadata;
    eprintln!(
        "Reference data: {} orbits, Python REBOUND {} + ASSIST {}",
        reference.orbits.len(),
        meta.rebound_version,
        meta.assist_version,
    );
    eprintln!(
        "Propagation: {} days, {} epochs per orbit",
        meta.propagation_days, meta.n_epochs
    );

    let data = assist_rs::AssistData::new(
        assist_rs::Ephemeris::from_paths(&planets, &asteroids).expect("Failed to load ephemeris"),
    );
    run_validation_batch("gravity", &data, &reference.orbits, None);
}
/// Checks propagation with non-gravitational forces against the Python reference.
///
/// Returns early (passing) when the ephemeris environment variables are unset, the
/// reference JSON is missing, or the JSON has no `nongrav_orbits`. Panics if non-grav
/// orbits are present but `metadata.nongrav_params` is not.
#[test]
fn test_nongrav_validation_against_python() {
    let (planets, asteroids) = match ephem_paths() {
        Some(paths) => paths,
        None => {
            eprintln!("Skipping: ASSIST_PLANETS_PATH / ASSIST_ASTEROIDS_PATH not set");
            return;
        }
    };
    let reference = match load_reference_data() {
        Some(reference) => reference,
        None => {
            eprintln!("Skipping: validation/reference_data.json not found");
            return;
        }
    };

    let orbits = &reference.nongrav_orbits;
    if orbits.is_empty() {
        eprintln!("Skipping: reference_data.json has no nongrav_orbits section");
        eprintln!("Regenerate it with: python validation/generate_reference.py");
        return;
    }

    let nongrav = reference
        .metadata
        .nongrav_params
        .expect("nongrav_orbits present but metadata.nongrav_params missing");
    let [a1, a2, a3] = nongrav;
    eprintln!(
        "Non-grav reference data: {} orbits with (A1, A2, A3) = ({:.2e}, {:.2e}, {:.2e})",
        orbits.len(),
        a1,
        a2,
        a3,
    );

    let data = assist_rs::AssistData::new(
        assist_rs::Ephemeris::from_paths(&planets, &asteroids).expect("Failed to load ephemeris"),
    );
    run_validation_batch("non-gravitational", &data, orbits, Some(nongrav));
}
/// Propagates the same orbit on several rayon workers that share one `AssistData`, and
/// checks that every run matches the first one.
///
/// Returns early (passing) when the ephemeris environment variables are unset.
#[test]
fn test_rayon_parallel_propagation() {
    let Some((planets, asteroids)) = ephem_paths() else {
        eprintln!("Skipping: ephemeris not available");
        return;
    };
    let ephem =
        assist_rs::Ephemeris::from_paths(&planets, &asteroids).expect("Failed to load ephemeris");
    let data = assist_rs::AssistData::new(ephem);
    let orbit = assist_rs::Orbit::new(
        [
            -1.938_169_72,
            2.289_213_79,
            1.094_048_30,
            -0.008_744_54,
            -0.005_523_16,
            0.001_174_22,
        ],
        60000.0,
    );
    let targets = vec![orbit.epoch + 10.0, orbit.epoch + 20.0, orbit.epoch + 30.0];
    let n_orbits = 10;
    let results: Vec<_> = (0..n_orbits)
        .into_par_iter()
        .map(|_| {
            assist_rs::assist_propagate_single(
                &data,
                &orbit,
                &targets,
                false,
                &assist_rs::IntegratorConfig::default(),
            )
            .expect("Propagation failed")
        })
        .collect();
    assert_eq!(results.len(), n_orbits);

    // Pin the lengths first: otherwise empty or truncated result lists would skip the
    // element-wise comparison below and let the test pass vacuously.
    let baseline = &results[0];
    assert_eq!(
        baseline.len(),
        targets.len(),
        "Expected one propagated state per target epoch"
    );
    for (i, orbit_results) in results.iter().enumerate().skip(1) {
        assert_eq!(
            orbit_results.len(),
            baseline.len(),
            "Parallel result length mismatch: orbit {i}"
        );
        for (j, (expected, actual)) in baseline.iter().zip(orbit_results).enumerate() {
            for k in 0..6 {
                assert!(
                    (expected.state[k] - actual.state[k]).abs() < 1e-15,
                    "Parallel result mismatch: orbit {i}, epoch {j}, element {k}"
                );
            }
        }
    }
    eprintln!("Parallel propagation: {n_orbits} identical orbits produced identical results");
}