use nalgebra::{Vector3, Vector6};
use std::collections::HashMap;
use crate::derivatives::compute_derivatives;
use crate::wind::WindSegment;
use crate::BallisticInputs;
use crate::DragModel;
/// Advances `state` by one classical fourth-order Runge–Kutta step of length `dt`,
/// starting at time `t`, using `compute_derivatives_vec` for every slope sample.
fn rk4_step(state: &Vector6<f64>, t: f64, dt: f64, params: &TrajectoryParams) -> Vector6<f64> {
    let half_dt = dt * 0.5;
    let t_mid = t + half_dt;
    let slope = |y: &Vector6<f64>, at: f64| compute_derivatives_vec(y, at, params);

    let k1 = slope(state, t);
    let k2 = slope(&(state + half_dt * k1), t_mid);
    let k3 = slope(&(state + half_dt * k2), t_mid);
    let k4 = slope(&(state + dt * k3), t + dt);

    // Standard RK4 weighting: (k1 + 2*k2 + 2*k3 + k4) / 6.
    let weighted = k1 + 2.0 * k2 + 2.0 * k3 + k4;
    state + (dt / 6.0) * weighted
}
/// Performs one Dormand–Prince 5(4) step of length `dt` from `(t, state)`.
///
/// Returns `(y_new, dt_new, error)`:
/// * `y_new` — the fifth-order solution at `t + dt`;
/// * `error` — norm of the difference between the fifth- and embedded fourth-order
///   solutions, divided by `1 + |state|` (positions and velocities share one norm);
/// * `dt_new` — suggested next step: `0.9 * dt * (tol / error)^0.2` capped at a factor of 2
///   when `error < tol`, otherwise `0.9 * dt * (tol / error)^0.25` floored at a factor of 0.1.
///
/// The step itself is always returned; deciding whether to reject it is up to the caller.
fn rk45_step(
    state: &Vector6<f64>,
    t: f64,
    dt: f64,
    params: &TrajectoryParams,
    tol: f64,
) -> (Vector6<f64>, f64, f64) {
    // Stage coefficients of the Dormand–Prince tableau.
    const A21: f64 = 1.0 / 5.0;
    const A31: f64 = 3.0 / 40.0;
    const A32: f64 = 9.0 / 40.0;
    const A41: f64 = 44.0 / 45.0;
    const A42: f64 = -56.0 / 15.0;
    const A43: f64 = 32.0 / 9.0;
    const A51: f64 = 19372.0 / 6561.0;
    const A52: f64 = -25360.0 / 2187.0;
    const A53: f64 = 64448.0 / 6561.0;
    const A54: f64 = -212.0 / 729.0;
    const A61: f64 = 9017.0 / 3168.0;
    const A62: f64 = -355.0 / 33.0;
    const A63: f64 = 46732.0 / 5247.0;
    const A64: f64 = 49.0 / 176.0;
    const A65: f64 = -5103.0 / 18656.0;
    // Fifth-order weights. They are also row 7 of the tableau, so the seventh stage is
    // evaluated exactly at the fifth-order result.
    const B1: f64 = 35.0 / 384.0;
    const B3: f64 = 500.0 / 1113.0;
    const B4: f64 = 125.0 / 192.0;
    const B5: f64 = -2187.0 / 6784.0;
    const B6: f64 = 11.0 / 84.0;
    // Embedded fourth-order weights used only for the error estimate.
    const E1: f64 = 5179.0 / 57600.0;
    const E3: f64 = 7571.0 / 16695.0;
    const E4: f64 = 393.0 / 640.0;
    const E5: f64 = -92097.0 / 339200.0;
    const E6: f64 = 187.0 / 2100.0;
    const E7: f64 = 1.0 / 40.0;
    const SAFETY: f64 = 0.9;

    let deriv = |y: &Vector6<f64>, at: f64| compute_derivatives_vec(y, at, params);

    let k1 = deriv(state, t);
    let k2 = deriv(&(state + dt * A21 * k1), t + dt * 0.2);
    let k3 = deriv(&(state + dt * (A31 * k1 + A32 * k2)), t + dt * 0.3);
    let k4 = deriv(&(state + dt * (A41 * k1 + A42 * k2 + A43 * k3)), t + dt * 0.8);
    let k5 = deriv(
        &(state + dt * (A51 * k1 + A52 * k2 + A53 * k3 + A54 * k4)),
        t + dt * 8.0 / 9.0,
    );
    let k6 = deriv(
        &(state + dt * (A61 * k1 + A62 * k2 + A63 * k3 + A64 * k4 + A65 * k5)),
        t + dt,
    );

    let y_new = state + dt * (B1 * k1 + B3 * k3 + B4 * k4 + B5 * k5 + B6 * k6);
    let k7 = deriv(&y_new, t + dt);
    let y_low = state
        + dt * (E1 * k1 + E3 * k3 + E4 * k4 + E5 * k5 + E6 * k6 + E7 * k7);

    let error = (y_new - y_low).norm() / (1.0 + state.norm());
    let ratio = tol / error;
    let scale = if error < tol {
        // Grow, but by at most a factor of 2 (before the safety factor).
        ratio.powf(0.2).min(2.0)
    } else {
        // Shrink, but never below a tenth (before the safety factor).
        ratio.powf(0.25).max(0.1)
    };
    (y_new, dt * SAFETY * scale, error)
}
/// Inputs shared by every derivative evaluation during one trajectory integration.
pub struct TrajectoryParams {
    /// Projectile mass in kilograms; converted to grains (÷ 0.00006479891) for
    /// `BallisticInputs`.
    pub mass_kg: f64,
    /// Ballistic coefficient, forwarded both inside `BallisticInputs` and directly to
    /// `compute_derivatives`.
    pub bc: f64,
    /// Drag reference model (e.g. `DragModel::G7`).
    pub drag_model: DragModel,
    /// Wind segments. Without wind shear only the first one is used: `.0` is the speed
    /// (scaled by 0.2777778 for the wind vector) and `.1` the direction in degrees.
    /// NOTE(review): the meaning of the remaining tuple field is not visible here.
    pub wind_segments: Vec<WindSegment>,
    /// Atmosphere as `(altitude, temperature, pressure, humidity)`.
    /// NOTE(review): units are whatever `BallisticInputs` expects — confirm.
    pub atmos_params: (f64, f64, f64, f64),
    /// Optional angular-velocity vector forwarded to `compute_derivatives`
    /// (presumably Earth rotation for Coriolis — TODO confirm).
    pub omega_vector: Option<Vector3<f64>>,
    /// Spin drift; any of spin drift / Magnus / Coriolis enables advanced effects.
    pub enable_spin_drift: bool,
    /// Magnus effect toggle.
    pub enable_magnus: bool,
    /// Coriolis effect toggle.
    pub enable_coriolis: bool,
    /// Downrange distance (compared against `state[2]`) at which integration stops.
    pub target_distance_m: f64,
    /// Use the position-dependent wind-shear path (requires `wind_shear_model != "none"`);
    /// also tightens the RK45 step cap.
    pub enable_wind_shear: bool,
    /// Wind-shear model name; `"none"` disables the shear path.
    pub wind_shear_model: String,
    /// Shooter altitude in metres, passed to the wind-shear model.
    pub shooter_altitude_m: f64,
    /// Right-hand rifling twist when `true`.
    pub is_twist_right: bool,
    /// Optional custom drag table, cloned into `BallisticInputs` on every evaluation.
    pub custom_drag_table: Option<crate::drag::DragTable>,
    /// Optional BC segments, cloned into `BallisticInputs` on every evaluation.
    pub bc_segments: Option<Vec<(f64, f64)>>,
    /// Whether `bc_segments` should be honoured downstream.
    pub use_bc_segments: bool,
}
/// Evaluates the time derivative of the packed state `[x, y, z, vx, vy, vz]` at time `t`.
///
/// Repackages [`TrajectoryParams`] into the [`BallisticInputs`] record expected by
/// [`compute_derivatives`], works out the wind vector, and returns the first six
/// components of the physics result.
///
/// Wind handling:
/// * no segments → zero wind;
/// * wind shear enabled and model not `"none"` → `crate::wind_shear::get_wind_at_position`;
/// * otherwise only the first segment is used: `.0 * 0.2777778` is the speed (km/h → m/s if
///   `.0` is km/h — NOTE(review): confirm units) and `.1` the direction in degrees.
///
/// NOTE(review): bullet diameter (0.308), length (1.24), twist rate (10.0) and
/// `target_distance` (1000.0) are fixed here, and `muzzle_velocity` is the *current* speed
/// scaled by 3.28084. Confirm this is intended for the effects that read these fields.
fn compute_derivatives_vec(
    state: &Vector6<f64>,
    t: f64,
    params: &TrajectoryParams,
) -> Vector6<f64> {
    let position = Vector3::new(state[0], state[1], state[2]);
    let velocity = Vector3::new(state[3], state[4], state[5]);

    // Speed and direction of the first configured wind segment; (0, 0) without wind.
    let first_segment = params.wind_segments.first();
    let (first_wind_speed, first_wind_angle) = match first_segment {
        Some(segment) => (segment.0, segment.1),
        None => (0.0, 0.0),
    };

    let use_shear = params.enable_wind_shear && params.wind_shear_model != "none";
    let wind_vector = match first_segment {
        None => Vector3::zeros(),
        Some(_) if use_shear => crate::wind_shear::get_wind_at_position(
            &position,
            &params.wind_segments,
            params.enable_wind_shear,
            &params.wind_shear_model,
            params.shooter_altitude_m,
        ),
        Some(_) => {
            let speed_mps = first_wind_speed * 0.2777778;
            let angle_rad = first_wind_angle.to_radians();
            // 0° yields a vector along -z, 90° along -x.
            Vector3::new(
                -speed_mps * angle_rad.sin(),
                0.0,
                -speed_mps * angle_rad.cos(),
            )
        }
    };

    let (altitude, temperature, pressure, humidity) = params.atmos_params;
    let mass_grains = params.mass_kg / 0.00006479891;
    let inputs = BallisticInputs {
        bc_value: params.bc,
        bc_type: params.drag_model,
        bullet_mass: mass_grains,
        muzzle_velocity: velocity.norm() * 3.28084,
        bullet_diameter: 0.308,
        bullet_length: 1.24,
        twist_rate: 10.0,
        is_twist_right: params.is_twist_right,
        enable_advanced_effects: params.enable_spin_drift
            || params.enable_magnus
            || params.enable_coriolis,
        altitude,
        temperature,
        pressure,
        humidity,
        tipoff_yaw: 0.0,
        target_distance: 1000.0,
        muzzle_angle: 0.0,
        wind_speed: first_wind_speed,
        wind_angle: first_wind_angle,
        latitude: None,
        shooting_angle: 0.0,
        azimuth_angle: 0.0,
        use_powder_sensitivity: false,
        powder_temp_sensitivity: 0.0,
        powder_temp: 59.0,
        tipoff_decay_distance: 0.0,
        ground_threshold: -1000.0,
        bc_segments: params.bc_segments.clone(),
        caliber_inches: 0.308,
        weight_grains: mass_grains,
        use_bc_segments: params.use_bc_segments,
        bullet_id: None,
        bc_segments_data: None,
        use_enhanced_spin_drift: params.enable_spin_drift,
        use_form_factor: false,
        manufacturer: None,
        bullet_model: None,
        // Shear (if any) is already folded into `wind_vector` above.
        enable_wind_shear: false,
        wind_shear_model: "none".to_string(),
        use_cluster_bc: false,
        bullet_cluster: None,
        custom_drag_table: params.custom_drag_table.clone(),
        bc_type_str: None,
        enable_pitch_damping: false,
        enable_precession_nutation: false,
        use_rk4: true,
        use_adaptive_rk45: false,
        enable_trajectory_sampling: false,
        sample_interval: 10.0,
        sight_height: 0.0,
        muzzle_height: 0.0,
        target_height: 0.0,
    };

    let derivs = compute_derivatives(
        position,
        velocity,
        &inputs,
        wind_vector,
        params.atmos_params,
        params.bc,
        params.omega_vector,
        t,
    );
    Vector6::new(derivs[0], derivs[1], derivs[2], derivs[3], derivs[4], derivs[5])
}
/// Integrates the projectile state over `t_span` and returns the sampled trajectory.
///
/// `initial_state` is packed as `[x, y, z, vx, vy, vz]`. Integration stops at the first of:
/// * `state[2]` reaching `params.target_distance_m`. The crossing time is estimated by linear
///   interpolation within the step, and the state is re-integrated to that time with `z`
///   clamped so it never exceeds the target;
/// * `state[1]` dropping below -1000;
/// * the end of `t_span`;
/// * (RK45 only) 100 000 steps, which also prints a warning to stderr.
///
/// `method == "RK4"` selects fixed-step RK4 with a step of at most `min(max_step, 0.001)`
/// and records every step. Any other value (including `"RK45"`) selects the adaptive
/// Dormand–Prince integrator. Its step is bounded above by `max_step` (replaced by 0.01/0.02
/// when wind shear is enabled) and below by 0.0001. It records a sample about every
/// `target_distance_m / 50` of downrange travel, plus the state where integration stopped.
/// `tolerance` is used only by RK45.
///
/// The returned vector always starts with `(t_span.0, initial_state)`.
pub fn integrate_trajectory(
    initial_state: [f64; 6],
    t_span: (f64, f64),
    params: TrajectoryParams,
    method: &str,
    tolerance: f64,
    max_step: f64,
) -> Vec<(f64, Vector6<f64>)> {
    let state = Vector6::from_row_slice(&initial_state);
    let mut trajectory = Vec::with_capacity(10000);
    trajectory.push((t_span.0, state));
    match method {
        "RK4" => integrate_fixed_rk4(&mut trajectory, state, t_span, &params, max_step),
        // Any other method name deliberately falls back to adaptive RK45.
        _ => integrate_adaptive_rk45(
            &mut trajectory,
            state,
            t_span,
            &params,
            tolerance,
            max_step,
        ),
    }
    trajectory
}

/// Lands a step that crosses `target_z` exactly on the target.
///
/// `state`/`new_state` are the states at the start/end of the crossing step of length `dt`
/// starting at `t`. `step_by(h)` must integrate from `state` by `h`. Returns the time and
/// state at the estimated crossing, with `z` clamped so it never overshoots the target.
fn land_on_target(
    state: &Vector6<f64>,
    new_state: &Vector6<f64>,
    t: f64,
    dt: f64,
    target_z: f64,
    step_by: impl FnOnce(f64) -> Vector6<f64>,
) -> (f64, Vector6<f64>) {
    // Fraction of the step at which z reaches the target, assuming z is linear over it.
    let alpha = (target_z - state[2]) / (new_state[2] - state[2]);
    let dt_to_target = dt * alpha;
    let mut hit = step_by(dt_to_target);
    if hit[2] > target_z {
        hit[2] = target_z;
    }
    (t + dt_to_target, hit)
}

/// Fixed-step RK4 driver for `integrate_trajectory`; appends every step to `trajectory`.
fn integrate_fixed_rk4(
    trajectory: &mut Vec<(f64, Vector6<f64>)>,
    mut state: Vector6<f64>,
    t_span: (f64, f64),
    params: &TrajectoryParams,
    max_step: f64,
) {
    let (mut t, t_end) = t_span;
    let target_z = params.target_distance_m;
    let mut dt = ((t_end - t) / 1000.0).min(max_step).min(0.001);
    while t < t_end {
        // Shorten the final step so it ends exactly at t_end.
        if t + dt > t_end {
            dt = t_end - t;
        }
        let new_state = rk4_step(&state, t, dt, params);
        if state[2] < target_z && new_state[2] >= target_z {
            let hit = land_on_target(&state, &new_state, t, dt, target_z, |h| {
                rk4_step(&state, t, h, params)
            });
            trajectory.push(hit);
            return;
        }
        state = new_state;
        t += dt;
        if state[2] >= target_z {
            // Reached without a crossing (the step started at or beyond the target):
            // record one clamped terminal sample instead of two samples at the same time.
            let mut final_state = state;
            final_state[2] = target_z;
            trajectory.push((t, final_state));
            return;
        }
        trajectory.push((t, state));
        if state[1] < -1000.0 {
            return;
        }
    }
}

/// Adaptive Dormand–Prince driver for `integrate_trajectory`.
///
/// Samples are thinned to roughly every `target_distance_m / 50` of downrange travel.
/// The state where integration stops is always appended.
///
/// NOTE(review): the error estimate from `rk45_step` only steers the next step size; a
/// step whose error exceeds `tolerance` is still accepted.
fn integrate_adaptive_rk45(
    trajectory: &mut Vec<(f64, Vector6<f64>)>,
    mut state: Vector6<f64>,
    t_span: (f64, f64),
    params: &TrajectoryParams,
    tolerance: f64,
    max_step: f64,
) {
    const MIN_DT: f64 = 0.0001;
    const MAX_ITERATIONS: usize = 100_000;

    let (mut t, t_end) = t_span;
    let target_z = params.target_distance_m;
    let save_interval_m = target_z / 50.0;
    // Measure sampling distance from the actual starting z, not from z = 0.
    let mut last_save_z = state[2];
    let mut last_save_t = t;
    // With wind shear enabled, the step cap is 0.01 (targets beyond 800 m) or 0.02,
    // overriding `max_step`.
    let effective_max_step = if params.enable_wind_shear && params.wind_shear_model != "none" {
        if target_z > 800.0 {
            0.01
        } else {
            0.02
        }
    } else {
        max_step
    };
    let mut dt = ((t_end - t) / 1000.0).min(effective_max_step).max(MIN_DT);
    let mut iterations = 0usize;

    while t < t_end {
        if iterations >= MAX_ITERATIONS {
            eprintln!(
                "WARNING: Trajectory integration hit maximum iteration limit ({} iterations)",
                MAX_ITERATIONS
            );
            eprintln!(" Final time: {}, Target time: {}", t, t_end);
            eprintln!(
                " Final position: z={}, Target: {}m",
                state[2], params.target_distance_m
            );
            break;
        }
        iterations += 1;

        if t + dt > t_end {
            dt = t_end - t;
        }
        let (new_state, dt_new, _) = rk45_step(&state, t, dt, params, tolerance);
        if state[2] < target_z && new_state[2] >= target_z {
            let hit = land_on_target(&state, &new_state, t, dt, target_z, |h| {
                rk45_step(&state, t, h, params, tolerance).0
            });
            trajectory.push(hit);
            return;
        }
        state = new_state;
        t += dt;
        if state[2] >= target_z {
            // Reached without a crossing (the step started at or beyond the target).
            let mut final_state = state;
            final_state[2] = target_z;
            trajectory.push((t, final_state));
            return;
        }
        if state[2] - last_save_z >= save_interval_m {
            trajectory.push((t, state));
            last_save_z = state[2];
            last_save_t = t;
        }
        dt = dt_new.min(effective_max_step).max(MIN_DT);
        if state[1] < -1000.0 {
            break;
        }
    }

    // Record where integration stopped (ground, t_end, or iteration cap). Without this, the
    // last sample could lag the terminal state by up to one save interval.
    if t > last_save_t {
        trajectory.push((t, state));
    }
}
/// Convenience wrapper around `integrate_trajectory` that returns every sample as a
/// string-keyed map with the keys `t`, `x`, `y`, `z`, `vx`, `vy`, `vz`.
///
/// Wind shear is disabled, twist is right-handed, and no custom drag table or BC segments
/// are used. When `omega_vector` is present, only its first three components are read, and
/// a shorter vector panics on indexing.
pub fn solve_trajectory_rust(
    initial_state: [f64; 6],
    t_span: (f64, f64),
    mass_kg: f64,
    bc: f64,
    drag_model: DragModel,
    wind_segments: Vec<WindSegment>,
    atmos_params: (f64, f64, f64, f64),
    omega_vector: Option<Vec<f64>>,
    enable_spin_drift: bool,
    enable_magnus: bool,
    enable_coriolis: bool,
    method: String,
    tolerance: f64,
    max_step: f64,
    target_distance_m: f64,
) -> Vec<HashMap<String, f64>> {
    // Output keys for state components 0..6, in order.
    const STATE_KEYS: [&str; 6] = ["x", "y", "z", "vx", "vy", "vz"];

    let params = TrajectoryParams {
        mass_kg,
        bc,
        drag_model,
        wind_segments,
        atmos_params,
        omega_vector: omega_vector.map(|w| Vector3::new(w[0], w[1], w[2])),
        enable_spin_drift,
        enable_magnus,
        enable_coriolis,
        target_distance_m,
        enable_wind_shear: false,
        wind_shear_model: "none".to_string(),
        shooter_altitude_m: 0.0,
        is_twist_right: true,
        custom_drag_table: None,
        bc_segments: None,
        use_bc_segments: false,
    };

    let samples =
        integrate_trajectory(initial_state, t_span, params, &method, tolerance, max_step);
    let mut rows = Vec::with_capacity(samples.len());
    for (t, state) in samples {
        let mut row = HashMap::with_capacity(STATE_KEYS.len() + 1);
        row.insert("t".to_string(), t);
        for (key, value) in STATE_KEYS.iter().zip(state.iter()) {
            row.insert(key.to_string(), *value);
        }
        rows.push(row);
    }
    rows
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline test inputs: 0.01134 kg, G7 BC 0.442, atmosphere tuple
    /// `(0.0, 59.0, 29.92, 0.0)`, no wind, and every optional effect disabled.
    fn create_test_params(target_distance_m: f64) -> TrajectoryParams {
        TrajectoryParams {
            mass_kg: 0.01134,
            bc: 0.442,
            drag_model: DragModel::G7,
            wind_segments: vec![],
            atmos_params: (0.0, 59.0, 29.92, 0.0),
            omega_vector: None,
            enable_spin_drift: false,
            enable_magnus: false,
            enable_coriolis: false,
            target_distance_m,
            enable_wind_shear: false,
            wind_shear_model: "none".to_string(),
            shooter_altitude_m: 0.0,
            is_twist_right: true,
            custom_drag_table: None,
            bc_segments: None,
            use_bc_segments: false,
        }
    }

    #[test]
    fn test_integrate_trajectory_basic() {
        let initial_state = [0.0, -0.038, 0.0, 0.0, 48.61, 821.52];
        let params = TrajectoryParams {
            wind_segments: vec![(0.0, 90.0, 914.4)],
            ..create_test_params(914.4)
        };
        println!("Running integrate_trajectory test...");
        println!("Initial state: {:?}", initial_state);
        println!("Target distance: {} m", params.target_distance_m);
        let trajectory =
            integrate_trajectory(initial_state, (0.0, 10.0), params, "RK45", 1e-6, 0.01);
        println!("Trajectory has {} points", trajectory.len());
        assert!(
            trajectory.len() > 1,
            "Trajectory should have more than 1 point, but has {}",
            trajectory.len()
        );
        if let Some((_, final_state)) = trajectory.last() {
            println!("Final state: z={}", final_state[2]);
            assert!(
                final_state[2] > 0.0,
                "Final z should be positive (bullet moved downrange)"
            );
            assert!(
                final_state[2] >= 900.0,
                "Final z should be near target distance"
            );
        }
    }

    #[test]
    fn test_rk4_vs_rk45_consistency() {
        let initial_state = [0.0, 0.0, 0.0, 0.0, 30.0, 800.0];
        let target_distance = 500.0;
        let params_rk4 = create_test_params(target_distance);
        let params_rk45 = create_test_params(target_distance);
        let trajectory_rk4 =
            integrate_trajectory(initial_state, (0.0, 5.0), params_rk4, "RK4", 1e-6, 0.001);
        let trajectory_rk45 =
            integrate_trajectory(initial_state, (0.0, 5.0), params_rk45, "RK45", 1e-6, 0.01);
        assert!(!trajectory_rk4.is_empty());
        assert!(!trajectory_rk45.is_empty());
        let (t_rk4, final_rk4) = trajectory_rk4.last().unwrap();
        let (t_rk45, final_rk45) = trajectory_rk45.last().unwrap();

        let rk4_z = final_rk4[2];
        let rk45_z = final_rk45[2];
        let diff_percent = ((rk4_z - rk45_z) / rk45_z).abs() * 100.0;
        assert!(
            diff_percent < 1.0,
            "RK4 and RK45 final positions differ by {}%: RK4={}, RK45={}",
            diff_percent,
            rk4_z,
            rk45_z
        );

        // Both runs clamp z to the target, so z alone cannot reveal a discrepancy;
        // the time of flight to the target reflects the integrators' accuracy.
        let time_diff_percent = ((*t_rk4 - *t_rk45) / *t_rk45).abs() * 100.0;
        assert!(
            time_diff_percent < 1.0,
            "RK4 and RK45 times to target differ by {}%: RK4={}, RK45={}",
            time_diff_percent,
            t_rk4,
            t_rk45
        );
    }

    #[test]
    fn test_ground_impact_detection() {
        let initial_state = [0.0, 100.0, 0.0, 0.0, -50.0, 300.0];
        let params = create_test_params(10000.0);
        let trajectory =
            integrate_trajectory(initial_state, (0.0, 20.0), params, "RK45", 1e-6, 0.01);
        let (_, final_state) = trajectory.last().unwrap();
        assert!(
            final_state[1] <= -900.0,
            "Should hit ground, but y={}",
            final_state[1]
        );
        assert!(
            final_state[2] < 10000.0,
            "Should not reach target, but z={}",
            final_state[2]
        );
    }

    #[test]
    fn test_target_distance_reached() {
        let initial_state = [0.0, 0.0, 0.0, 0.0, 20.0, 800.0];
        let target_distance = 300.0;
        let params = create_test_params(target_distance);
        let trajectory =
            integrate_trajectory(initial_state, (0.0, 5.0), params, "RK45", 1e-6, 0.01);
        let (_, final_state) = trajectory.last().unwrap();
        assert!(
            (final_state[2] - target_distance).abs() < 1.0,
            "Should reach target at {}m, but stopped at {}m",
            target_distance,
            final_state[2]
        );
    }

    #[test]
    fn test_wind_affects_trajectory() {
        let initial_state = [0.0, 0.0, 0.0, 0.0, 30.0, 800.0];
        let target_distance = 500.0;
        let params_no_wind = create_test_params(target_distance);
        let mut params_headwind = create_test_params(target_distance);
        params_headwind.wind_segments = vec![(72.0, 0.0, 500.0)];
        let trajectory_no_wind =
            integrate_trajectory(initial_state, (0.0, 5.0), params_no_wind, "RK45", 1e-6, 0.01);
        let trajectory_headwind =
            integrate_trajectory(initial_state, (0.0, 5.0), params_headwind, "RK45", 1e-6, 0.01);
        assert!(!trajectory_no_wind.is_empty(), "No-wind trajectory should complete");
        assert!(!trajectory_headwind.is_empty(), "Headwind trajectory should complete");
        let (time_no_wind, final_no_wind) = trajectory_no_wind.last().unwrap();
        let (time_headwind, final_headwind) = trajectory_headwind.last().unwrap();
        let drop_no_wind = final_no_wind[1];
        let drop_headwind = final_headwind[1];
        println!("No wind: time={}, drop={}", time_no_wind, drop_no_wind);
        println!("Headwind: time={}, drop={}", time_headwind, drop_headwind);
        assert!(
            (final_no_wind[2] - target_distance).abs() < 10.0,
            "No-wind should reach target"
        );
        assert!(
            (final_headwind[2] - target_distance).abs() < 10.0,
            "Headwind should reach target"
        );
        // The point of this test: a non-zero wind must actually change the flight.
        assert!(
            (*time_headwind - *time_no_wind).abs() > 1e-9
                || (drop_headwind - drop_no_wind).abs() > 1e-9,
            "Headwind should change time of flight or drop: no wind ({}, {}), headwind ({}, {})",
            time_no_wind,
            drop_no_wind,
            time_headwind,
            drop_headwind
        );
    }

    #[test]
    fn test_solve_trajectory_rust_output_format() {
        let initial_state = [0.0, 0.0, 0.0, 0.0, 30.0, 800.0];
        let result = solve_trajectory_rust(
            initial_state,
            (0.0, 2.0),
            0.01134,
            0.442,
            DragModel::G7,
            vec![],
            (0.0, 59.0, 29.92, 0.0),
            None,
            false,
            false,
            false,
            "RK45".to_string(),
            1e-6,
            0.01,
            500.0,
        );
        assert!(!result.is_empty());
        let first_point = &result[0];
        for key in ["t", "x", "y", "z", "vx", "vy", "vz"] {
            assert!(first_point.contains_key(key), "missing key {}", key);
        }
    }

    #[test]
    fn test_left_vs_right_twist() {
        let initial_state = [0.0, 0.0, 0.0, 0.0, 30.0, 800.0];
        let target_distance = 500.0;
        let mut params_right = create_test_params(target_distance);
        params_right.is_twist_right = true;
        params_right.enable_spin_drift = true;
        let mut params_left = create_test_params(target_distance);
        params_left.is_twist_right = false;
        params_left.enable_spin_drift = true;
        let trajectory_right =
            integrate_trajectory(initial_state, (0.0, 5.0), params_right, "RK45", 1e-6, 0.01);
        let trajectory_left =
            integrate_trajectory(initial_state, (0.0, 5.0), params_left, "RK45", 1e-6, 0.01);
        assert!(!trajectory_right.is_empty());
        assert!(!trajectory_left.is_empty());
        let (_, final_right) = trajectory_right.last().unwrap();
        let (_, final_left) = trajectory_left.last().unwrap();
        assert!((final_right[2] - final_left[2]).abs() < 10.0);
    }
}