ballistics_engine/
trajectory_integration.rs

1//! Advanced trajectory integration methods (RK4, RK45)
2//!
3//! This module provides production-grade numerical integration for ballistic trajectories:
4//! - RK4: 4th-order Runge-Kutta (fixed step)
5//! - RK45: Dormand-Prince adaptive method (same as scipy.integrate.solve_ivp)
6//!
7//! MBA-155: Upstreamed from ballistics_rust for shared use
8
9use nalgebra::{Vector3, Vector6};
10use std::collections::HashMap;
11
12use crate::derivatives::compute_derivatives;
13use crate::wind::WindSegment;
14use crate::DragModel;
15use crate::BallisticInputs;
16
17/// RK4 integration step
18fn rk4_step(
19    state: &Vector6<f64>,
20    t: f64,
21    dt: f64,
22    params: &TrajectoryParams,
23) -> Vector6<f64> {
24    // RK4 integration
25    let k1 = compute_derivatives_vec(state, t, params);
26    let k2 = compute_derivatives_vec(&(state + dt * 0.5 * k1), t + dt * 0.5, params);
27    let k3 = compute_derivatives_vec(&(state + dt * 0.5 * k2), t + dt * 0.5, params);
28    let k4 = compute_derivatives_vec(&(state + dt * k3), t + dt, params);
29
30    state + (dt / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4)
31}
32
33/// Adaptive RK45 integration step (Dormand-Prince method)
34fn rk45_step(
35    state: &Vector6<f64>,
36    t: f64,
37    dt: f64,
38    params: &TrajectoryParams,
39    tol: f64,
40) -> (Vector6<f64>, f64, f64) {
41    // Dormand-Prince coefficients (same as scipy.integrate.solve_ivp RK45)
42    const A21: f64 = 1.0 / 5.0;
43    const A31: f64 = 3.0 / 40.0;
44    const A32: f64 = 9.0 / 40.0;
45    const A41: f64 = 44.0 / 45.0;
46    const A42: f64 = -56.0 / 15.0;
47    const A43: f64 = 32.0 / 9.0;
48    const A51: f64 = 19372.0 / 6561.0;
49    const A52: f64 = -25360.0 / 2187.0;
50    const A53: f64 = 64448.0 / 6561.0;
51    const A54: f64 = -212.0 / 729.0;
52    const A61: f64 = 9017.0 / 3168.0;
53    const A62: f64 = -355.0 / 33.0;
54    const A63: f64 = 46732.0 / 5247.0;
55    const A64: f64 = 49.0 / 176.0;
56    const A65: f64 = -5103.0 / 18656.0;
57    const A71: f64 = 35.0 / 384.0;
58    const A73: f64 = 500.0 / 1113.0;
59    const A74: f64 = 125.0 / 192.0;
60    const A75: f64 = -2187.0 / 6784.0;
61    const A76: f64 = 11.0 / 84.0;
62
63    // 5th order coefficients
64    const B1: f64 = 35.0 / 384.0;
65    const B3: f64 = 500.0 / 1113.0;
66    const B4: f64 = 125.0 / 192.0;
67    const B5: f64 = -2187.0 / 6784.0;
68    const B6: f64 = 11.0 / 84.0;
69
70    // 4th order coefficients (for error estimation)
71    const B1_ERR: f64 = 5179.0 / 57600.0;
72    const B3_ERR: f64 = 7571.0 / 16695.0;
73    const B4_ERR: f64 = 393.0 / 640.0;
74    const B5_ERR: f64 = -92097.0 / 339200.0;
75    const B6_ERR: f64 = 187.0 / 2100.0;
76    const B7_ERR: f64 = 1.0 / 40.0;
77
78    // Compute stages
79    let k1 = compute_derivatives_vec(state, t, params);
80    let k2 = compute_derivatives_vec(&(state + dt * A21 * k1), t + dt * 0.2, params);
81    let k3 = compute_derivatives_vec(&(state + dt * (A31 * k1 + A32 * k2)), t + dt * 0.3, params);
82    let k4 = compute_derivatives_vec(&(state + dt * (A41 * k1 + A42 * k2 + A43 * k3)), t + dt * 0.8, params);
83    let k5 = compute_derivatives_vec(&(state + dt * (A51 * k1 + A52 * k2 + A53 * k3 + A54 * k4)), t + dt * 8.0/9.0, params);
84    let k6 = compute_derivatives_vec(&(state + dt * (A61 * k1 + A62 * k2 + A63 * k3 + A64 * k4 + A65 * k5)), t + dt, params);
85    let k7 = compute_derivatives_vec(&(state + dt * (A71 * k1 + A73 * k3 + A74 * k4 + A75 * k5 + A76 * k6)), t + dt, params);
86
87    // 5th order solution
88    let y_new = state + dt * (B1 * k1 + B3 * k3 + B4 * k4 + B5 * k5 + B6 * k6);
89
90    // 4th order solution for error estimate
91    let y_err = state + dt * (B1_ERR * k1 + B3_ERR * k3 + B4_ERR * k4 + B5_ERR * k5 + B6_ERR * k6 + B7_ERR * k7);
92
93    // Error estimate
94    let error = (y_new - y_err).norm() / (1.0 + state.norm());
95
96    // Adaptive step size
97    let safety = 0.9;
98    let dt_new = if error < tol {
99        dt * safety * (tol / error).powf(0.2).min(2.0)
100    } else {
101        dt * safety * (tol / error).powf(0.25).max(0.1)
102    };
103
104    (y_new, dt_new, error)
105}
106
107/// Parameters for trajectory computation
108pub struct TrajectoryParams {
109    pub mass_kg: f64,
110    pub bc: f64,
111    pub drag_model: DragModel,
112    pub wind_segments: Vec<WindSegment>,
113    pub atmos_params: (f64, f64, f64, f64),
114    pub omega_vector: Option<Vector3<f64>>,
115    pub enable_spin_drift: bool,
116    pub enable_magnus: bool,
117    pub enable_coriolis: bool,
118    pub target_distance_m: f64,  // Target horizontal distance in meters
119    pub enable_wind_shear: bool,
120    pub wind_shear_model: String,
121    pub shooter_altitude_m: f64,
122    pub is_twist_right: bool,  // True for right-hand twist, false for left-hand
123    pub custom_drag_table: Option<crate::drag::DragTable>,  // Custom Drag Model (CDM) data
124    pub bc_segments: Option<Vec<(f64, f64)>>,  // Mach-based BC segments: (mach, bc)
125    pub use_bc_segments: bool,  // Whether to use BC segment interpolation
126}
127
128/// Convert state to Vector6 and call compute_derivatives
129fn compute_derivatives_vec(
130    state: &Vector6<f64>,
131    t: f64,
132    params: &TrajectoryParams,
133) -> Vector6<f64> {
134
135    let pos = Vector3::new(state[0], state[1], state[2]);
136    let vel = Vector3::new(state[3], state[4], state[5]);
137
138    // Calculate wind at current position with shear support
139    let wind_vector = if !params.wind_segments.is_empty() {
140        if params.enable_wind_shear && params.wind_shear_model != "none" {
141            crate::wind_shear::get_wind_at_position(
142                &pos,
143                &params.wind_segments,
144                params.enable_wind_shear,
145                &params.wind_shear_model,
146                params.shooter_altitude_m,
147            )
148        } else {
149            // Simple constant wind (original implementation)
150            let seg = &params.wind_segments[0];
151            let wind_speed_mps = seg.0 * 0.2777778; // km/h to m/s
152            let wind_angle_rad = seg.1.to_radians();
153            // Z IS DOWNRANGE: x=lateral, y=vertical, z=downrange
154            Vector3::new(
155                -wind_speed_mps * wind_angle_rad.sin(),  // x (lateral - crosswind component)
156                0.0,                                      // y (vertical)
157                -wind_speed_mps * wind_angle_rad.cos(),  // z (downrange - head/tail component)
158            )
159        }
160    } else {
161        Vector3::zeros()
162    };
163
164    // Create a minimal BallisticInputs struct for the derivatives function
165    let inputs = BallisticInputs {
166        bc_value: params.bc,
167        bc_type: params.drag_model.clone(),
168        bullet_mass: params.mass_kg / 0.00006479891, // kg to grains
169        muzzle_velocity: vel.norm() * 3.28084, // m/s to fps
170        bullet_diameter: 0.308, // default
171        bullet_length: 1.24, // default
172        twist_rate: 10.0, // default
173        is_twist_right: params.is_twist_right,
174        enable_advanced_effects: params.enable_spin_drift || params.enable_magnus || params.enable_coriolis,
175        altitude: params.atmos_params.0,
176        temperature: params.atmos_params.1,
177        pressure: params.atmos_params.2,
178        humidity: params.atmos_params.3,
179        tipoff_yaw: 0.0,
180        target_distance: 1000.0, // default
181        muzzle_angle: 0.0,
182        wind_speed: if !params.wind_segments.is_empty() { params.wind_segments[0].0 } else { 0.0 },
183        wind_angle: if !params.wind_segments.is_empty() { params.wind_segments[0].1 } else { 0.0 },
184        latitude: None,
185        shooting_angle: 0.0,
186        azimuth_angle: 0.0,
187        use_powder_sensitivity: false,
188        powder_temp_sensitivity: 0.0,
189        powder_temp: 59.0,
190        tipoff_decay_distance: 0.0,
191        ground_threshold: -1000.0,
192        bc_segments: params.bc_segments.clone(),
193        caliber_inches: 0.308,
194        weight_grains: params.mass_kg / 0.00006479891,
195        use_bc_segments: params.use_bc_segments,
196        bullet_id: None,
197        bc_segments_data: None,
198        use_enhanced_spin_drift: params.enable_spin_drift,
199        use_form_factor: false,
200        manufacturer: None,
201        bullet_model: None,
202        enable_wind_shear: false,
203        wind_shear_model: "none".to_string(),
204        use_cluster_bc: false,
205        bullet_cluster: None,
206
207        // Pass through custom drag table (CDM) from trajectory parameters
208        custom_drag_table: params.custom_drag_table.clone(),
209
210        bc_type_str: None,
211        enable_pitch_damping: false,
212        enable_precession_nutation: false,
213        use_rk4: true,
214        use_adaptive_rk45: false,
215        enable_trajectory_sampling: false,
216        sample_interval: 10.0,
217        sight_height: 0.0,
218        muzzle_height: 0.0,
219        target_height: 0.0,
220    };
221
222    // Call compute_derivatives - returns [f64; 6] directly
223    let deriv_result = compute_derivatives(
224        pos,
225        vel,
226        &inputs,
227        wind_vector,
228        params.atmos_params,
229        params.bc,
230        params.omega_vector,
231        t,
232    );
233
234    Vector6::new(
235        deriv_result[0], deriv_result[1], deriv_result[2],
236        deriv_result[3], deriv_result[4], deriv_result[5],
237    )
238}
239
240/// Main trajectory integration function
241pub fn integrate_trajectory(
242    initial_state: [f64; 6],
243    t_span: (f64, f64),
244    params: TrajectoryParams,
245    method: &str,
246    tolerance: f64,
247    max_step: f64,
248) -> Vec<(f64, Vector6<f64>)> {
249    let mut state = Vector6::new(
250        initial_state[0], initial_state[1], initial_state[2],
251        initial_state[3], initial_state[4], initial_state[5],
252    );
253
254    let mut t = t_span.0;
255    let t_end = t_span.1;
256    let mut dt = (t_end - t) / 1000.0; // Initial step size
257
258    let mut trajectory = Vec::with_capacity(10000);
259    trajectory.push((t, state.clone()));
260
261    match method {
262        "RK4" => {
263            // Fixed step RK4 with target detection
264            dt = dt.min(max_step).min(0.001); // Use smaller steps for accuracy
265
266            while t < t_end {
267                if t + dt > t_end {
268                    dt = t_end - t;
269                }
270
271                let new_state = rk4_step(&state, t, dt, &params);
272
273                // Check if we're about to pass the target (z is downrange)
274                if state[2] < params.target_distance_m && new_state[2] >= params.target_distance_m {
275                    // Interpolate to find exact target crossing
276                    let alpha = (params.target_distance_m - state[2]) / (new_state[2] - state[2]);
277                    let dt_to_target = dt * alpha;
278
279                    // Take a smaller step to reach target exactly
280                    let final_state = rk4_step(&state, t, dt_to_target, &params);
281
282                    // Ensure we don't overshoot
283                    let mut corrected_state = final_state;
284                    if corrected_state[2] > params.target_distance_m {
285                        corrected_state[2] = params.target_distance_m;
286                    }
287
288                    trajectory.push((t + dt_to_target, corrected_state));
289                    break;  // Stop at target
290                }
291
292                state = new_state;
293                t += dt;
294                trajectory.push((t, state.clone()));
295
296                // Check if we've reached or passed the target
297                if state[2] >= params.target_distance_m {  // z is downrange
298                    // Add final point exactly at target
299                    let mut final_state = state;
300                    final_state[2] = params.target_distance_m;  // z is downrange
301                    trajectory.push((t, final_state));
302                    break;
303                }
304
305                // Check if bullet hit ground
306                if state[1] < -1000.0 {
307                    break;
308                }
309            }
310        }
311        "RK45" | _ => {
312            // Adaptive RK45 with better sampling
313            let mut last_save_z = 0.0;  // z is downrange
314            let save_interval_m = params.target_distance_m / 50.0; // Save ~50 points minimum
315
316            // OPTIMIZATION: Adjust max step size when wind shear is enabled
317            // This improves numerical stability at long ranges
318            let effective_max_step = if params.enable_wind_shear && params.wind_shear_model != "none" {
319                // Use smaller steps for wind shear, but not TOO small
320                if params.target_distance_m > 800.0 {
321                    0.01  // Smaller steps for long range with shear (10ms)
322                } else {
323                    0.02  // Normal steps for medium range with shear (20ms)
324                }
325            } else {
326                max_step  // Use provided max_step when no wind shear
327            };
328
329            // Set initial step size - ensure it's reasonable
330            dt = dt.min(effective_max_step).max(0.0001);  // At least 0.1ms to avoid infinite loops
331
332            // Safety check: maximum iterations to prevent infinite loops
333            let max_iterations = 100000;  // Should be more than enough for any realistic trajectory
334            let mut iteration_count = 0;
335
336            while t < t_end && iteration_count < max_iterations {
337                iteration_count += 1;
338
339                // Limit time step for better resolution
340                if t + dt > t_end {
341                    dt = t_end - t;
342                }
343
344                let (new_state, dt_new, _error) = rk45_step(&state, t, dt, &params, tolerance);
345
346                // Check if we're about to pass the target (z is downrange)
347                if state[2] < params.target_distance_m && new_state[2] >= params.target_distance_m {
348                    // Interpolate to find exact target crossing
349                    let alpha = (params.target_distance_m - state[2]) / (new_state[2] - state[2]);
350                    let dt_to_target = dt * alpha;
351
352                    // Take a smaller step to reach target exactly
353                    let (final_state, _, _) = rk45_step(&state, t, dt_to_target, &params, tolerance);
354
355                    // Make sure we don't overshoot
356                    let mut corrected_state = final_state;
357                    if corrected_state[2] > params.target_distance_m {
358                        corrected_state[2] = params.target_distance_m;
359                    }
360
361                    trajectory.push((t + dt_to_target, corrected_state));
362                    break;  // Stop at target - no more points after this
363                }
364
365                // Update state
366                state = new_state;
367                t += dt;
368
369                // Save trajectory point if we've moved enough distance
370                if state[2] - last_save_z >= save_interval_m || state[2] >= params.target_distance_m {  // z is downrange
371                    trajectory.push((t, state.clone()));
372                    last_save_z = state[2];
373                }
374
375                // Limit dt for next step - ensure we get enough resolution
376                dt = dt_new.min(effective_max_step).max(0.0001); // Use effective max step, min 0.1ms
377
378                // Stop if we've reached the target
379                if state[2] >= params.target_distance_m {  // z is downrange
380                    // Add final point at target distance
381                    let mut final_state = state;
382                    final_state[2] = params.target_distance_m;  // z is downrange
383                    trajectory.push((t, final_state));
384                    break;
385                }
386
387                // Check if bullet hit ground
388                if state[1] < -1000.0 {
389                    break;
390                }
391            }
392
393            // Warn if we hit the iteration limit
394            if iteration_count >= max_iterations {
395                eprintln!("WARNING: Trajectory integration hit maximum iteration limit ({} iterations)", max_iterations);
396                eprintln!("  Final time: {}, Target time: {}", t, t_end);
397                eprintln!("  Final position: z={}, Target: {}m", state[2], params.target_distance_m);
398            }
399        }
400    }
401
402    trajectory
403}
404
405/// Python-exposed function for complete trajectory integration
406pub fn solve_trajectory_rust(
407    initial_state: [f64; 6],
408    t_span: (f64, f64),
409    mass_kg: f64,
410    bc: f64,
411    drag_model: DragModel,
412    wind_segments: Vec<WindSegment>,
413    atmos_params: (f64, f64, f64, f64),
414    omega_vector: Option<Vec<f64>>,
415    enable_spin_drift: bool,
416    enable_magnus: bool,
417    enable_coriolis: bool,
418    method: String,
419    tolerance: f64,
420    max_step: f64,
421    target_distance_m: f64,
422) -> Vec<HashMap<String, f64>> {
423    let omega_vec = omega_vector.map(|v| Vector3::new(v[0], v[1], v[2]));
424
425    let params = TrajectoryParams {
426        mass_kg,
427        bc,
428        drag_model,
429        wind_segments,
430        atmos_params,
431        omega_vector: omega_vec,
432        enable_spin_drift,
433        enable_magnus,
434        enable_coriolis,
435        target_distance_m,
436        enable_wind_shear: false,  // Default for test function
437        wind_shear_model: "none".to_string(),
438        shooter_altitude_m: 0.0,
439        is_twist_right: true,  // Default for test function
440        custom_drag_table: None,  // No CDM for test function
441        bc_segments: None,  // No BC segments for legacy function
442        use_bc_segments: false,
443    };
444
445    let trajectory = integrate_trajectory(
446        initial_state,
447        t_span,
448        params,
449        &method,
450        tolerance,
451        max_step,
452    );
453
454    // Convert to Python-friendly format
455    trajectory.into_iter().map(|(t, state)| {
456        let mut point = HashMap::new();
457        point.insert("t".to_string(), t);
458        point.insert("x".to_string(), state[0]);
459        point.insert("y".to_string(), state[1]);
460        point.insert("z".to_string(), state[2]);
461        point.insert("vx".to_string(), state[3]);
462        point.insert("vy".to_string(), state[4]);
463        point.insert("vz".to_string(), state[5]);
464        point
465    }).collect()
466}
467
#[cfg(test)]
mod tests {
    use super::*;

    /// Initial state matching the Python test: `[x, y, z, vx, vy, vz]`,
    /// z = 0 at the muzzle and vz = 821.52 m/s downrange.
    const REFERENCE_STATE: [f64; 6] = [0.0, -0.038, 0.0, 0.0, 48.61, 821.52];

    /// 175 gr G7 (BC 0.442) bullet aimed at a 1000-yard target, no wind or extra physics.
    fn reference_params() -> TrajectoryParams {
        TrajectoryParams {
            mass_kg: 0.01134, // 175 grains in kg
            bc: 0.442,
            drag_model: DragModel::G7,
            wind_segments: vec![(0.0, 90.0, 914.4)],
            atmos_params: (0.0, 59.0, 29.92, 0.0),
            omega_vector: None,
            enable_spin_drift: false,
            enable_magnus: false,
            enable_coriolis: false,
            target_distance_m: 914.4, // 1000 yards in meters
            enable_wind_shear: false,
            wind_shear_model: "none".to_string(),
            shooter_altitude_m: 0.0,
            is_twist_right: true,
            custom_drag_table: None,
            bc_segments: None,
            use_bc_segments: false,
        }
    }

    /// Integrates the reference shot with `method` and checks it reaches the target cleanly.
    fn assert_reaches_target(method: &str) {
        let params = reference_params();
        let target = params.target_distance_m;

        let trajectory =
            integrate_trajectory(REFERENCE_STATE, (0.0, 10.0), params, method, 1e-6, 0.01);

        assert!(
            trajectory.len() > 1,
            "{}: trajectory should have more than 1 point, but has {}",
            method,
            trajectory.len()
        );
        assert!(
            trajectory.windows(2).all(|w| w[1].0 >= w[0].0),
            "{}: sample times must be non-decreasing",
            method
        );

        let (_, final_state) = trajectory.last().expect("trajectory always holds the initial point");
        assert!(final_state[2] > 0.0, "{}: final z should be positive (bullet moved downrange)", method);
        assert!(final_state[2] >= 900.0, "{}: final z should be near target distance, got {}", method, final_state[2]);
        assert!(final_state[2] <= target, "{}: final z {} overshoots target {}", method, final_state[2], target);
    }

    #[test]
    fn test_integrate_trajectory_basic() {
        assert_reaches_target("RK45");
    }

    #[test]
    fn test_integrate_trajectory_rk4() {
        assert_reaches_target("RK4");
    }
}