ballistics_engine/
trajectory_integration.rs

1//! Advanced trajectory integration methods (RK4, RK45)
2//!
3//! This module provides production-grade numerical integration for ballistic trajectories:
4//! - RK4: 4th-order Runge-Kutta (fixed step)
5//! - RK45: Dormand-Prince 5(4) adaptive method (same Butcher tableau as scipy.integrate.solve_ivp's RK45; the error norm and step-size control differ)
6//!
7//! MBA-155: Upstreamed from ballistics_rust for shared use
8
9use nalgebra::{Vector3, Vector6};
10use std::collections::HashMap;
11
12use crate::derivatives::compute_derivatives;
13use crate::wind::WindSegment;
14use crate::BallisticInputs;
15use crate::DragModel;
16
17/// RK4 integration step
18fn rk4_step(state: &Vector6<f64>, t: f64, dt: f64, params: &TrajectoryParams) -> Vector6<f64> {
19    // RK4 integration
20    let k1 = compute_derivatives_vec(state, t, params);
21    let k2 = compute_derivatives_vec(&(state + dt * 0.5 * k1), t + dt * 0.5, params);
22    let k3 = compute_derivatives_vec(&(state + dt * 0.5 * k2), t + dt * 0.5, params);
23    let k4 = compute_derivatives_vec(&(state + dt * k3), t + dt, params);
24
25    state + (dt / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4)
26}
27
28/// Adaptive RK45 integration step (Dormand-Prince method)
29fn rk45_step(
30    state: &Vector6<f64>,
31    t: f64,
32    dt: f64,
33    params: &TrajectoryParams,
34    tol: f64,
35) -> (Vector6<f64>, f64, f64) {
36    // Dormand-Prince coefficients (same as scipy.integrate.solve_ivp RK45)
37    const A21: f64 = 1.0 / 5.0;
38    const A31: f64 = 3.0 / 40.0;
39    const A32: f64 = 9.0 / 40.0;
40    const A41: f64 = 44.0 / 45.0;
41    const A42: f64 = -56.0 / 15.0;
42    const A43: f64 = 32.0 / 9.0;
43    const A51: f64 = 19372.0 / 6561.0;
44    const A52: f64 = -25360.0 / 2187.0;
45    const A53: f64 = 64448.0 / 6561.0;
46    const A54: f64 = -212.0 / 729.0;
47    const A61: f64 = 9017.0 / 3168.0;
48    const A62: f64 = -355.0 / 33.0;
49    const A63: f64 = 46732.0 / 5247.0;
50    const A64: f64 = 49.0 / 176.0;
51    const A65: f64 = -5103.0 / 18656.0;
52    const A71: f64 = 35.0 / 384.0;
53    const A73: f64 = 500.0 / 1113.0;
54    const A74: f64 = 125.0 / 192.0;
55    const A75: f64 = -2187.0 / 6784.0;
56    const A76: f64 = 11.0 / 84.0;
57
58    // 5th order coefficients
59    const B1: f64 = 35.0 / 384.0;
60    const B3: f64 = 500.0 / 1113.0;
61    const B4: f64 = 125.0 / 192.0;
62    const B5: f64 = -2187.0 / 6784.0;
63    const B6: f64 = 11.0 / 84.0;
64
65    // 4th order coefficients (for error estimation)
66    const B1_ERR: f64 = 5179.0 / 57600.0;
67    const B3_ERR: f64 = 7571.0 / 16695.0;
68    const B4_ERR: f64 = 393.0 / 640.0;
69    const B5_ERR: f64 = -92097.0 / 339200.0;
70    const B6_ERR: f64 = 187.0 / 2100.0;
71    const B7_ERR: f64 = 1.0 / 40.0;
72
73    // Compute stages
74    let k1 = compute_derivatives_vec(state, t, params);
75    let k2 = compute_derivatives_vec(&(state + dt * A21 * k1), t + dt * 0.2, params);
76    let k3 = compute_derivatives_vec(&(state + dt * (A31 * k1 + A32 * k2)), t + dt * 0.3, params);
77    let k4 = compute_derivatives_vec(
78        &(state + dt * (A41 * k1 + A42 * k2 + A43 * k3)),
79        t + dt * 0.8,
80        params,
81    );
82    let k5 = compute_derivatives_vec(
83        &(state + dt * (A51 * k1 + A52 * k2 + A53 * k3 + A54 * k4)),
84        t + dt * 8.0 / 9.0,
85        params,
86    );
87    let k6 = compute_derivatives_vec(
88        &(state + dt * (A61 * k1 + A62 * k2 + A63 * k3 + A64 * k4 + A65 * k5)),
89        t + dt,
90        params,
91    );
92    let k7 = compute_derivatives_vec(
93        &(state + dt * (A71 * k1 + A73 * k3 + A74 * k4 + A75 * k5 + A76 * k6)),
94        t + dt,
95        params,
96    );
97
98    // 5th order solution
99    let y_new = state + dt * (B1 * k1 + B3 * k3 + B4 * k4 + B5 * k5 + B6 * k6);
100
101    // 4th order solution for error estimate
102    let y_err = state
103        + dt * (B1_ERR * k1 + B3_ERR * k3 + B4_ERR * k4 + B5_ERR * k5 + B6_ERR * k6 + B7_ERR * k7);
104
105    // Error estimate
106    let error = (y_new - y_err).norm() / (1.0 + state.norm());
107
108    // Adaptive step size
109    let safety = 0.9;
110    let dt_new = if error < tol {
111        dt * safety * (tol / error).powf(0.2).min(2.0)
112    } else {
113        dt * safety * (tol / error).powf(0.25).max(0.1)
114    };
115
116    (y_new, dt_new, error)
117}
118
119/// Parameters for trajectory computation
120pub struct TrajectoryParams {
121    pub mass_kg: f64,
122    pub bc: f64,
123    pub drag_model: DragModel,
124    pub wind_segments: Vec<WindSegment>,
125    pub atmos_params: (f64, f64, f64, f64),
126    pub omega_vector: Option<Vector3<f64>>,
127    pub enable_spin_drift: bool,
128    pub enable_magnus: bool,
129    pub enable_coriolis: bool,
130    pub target_distance_m: f64, // Target horizontal distance in meters
131    pub enable_wind_shear: bool,
132    pub wind_shear_model: String,
133    pub shooter_altitude_m: f64,
134    pub is_twist_right: bool, // True for right-hand twist, false for left-hand
135    pub custom_drag_table: Option<crate::drag::DragTable>, // Custom Drag Model (CDM) data
136    pub bc_segments: Option<Vec<(f64, f64)>>, // Mach-based BC segments: (mach, bc)
137    pub use_bc_segments: bool, // Whether to use BC segment interpolation
138}
139
140/// Convert state to Vector6 and call compute_derivatives
141fn compute_derivatives_vec(
142    state: &Vector6<f64>,
143    t: f64,
144    params: &TrajectoryParams,
145) -> Vector6<f64> {
146    let pos = Vector3::new(state[0], state[1], state[2]);
147    let vel = Vector3::new(state[3], state[4], state[5]);
148
149    // Calculate wind at current position with shear support
150    let wind_vector = if !params.wind_segments.is_empty() {
151        if params.enable_wind_shear && params.wind_shear_model != "none" {
152            crate::wind_shear::get_wind_at_position(
153                &pos,
154                &params.wind_segments,
155                params.enable_wind_shear,
156                &params.wind_shear_model,
157                params.shooter_altitude_m,
158            )
159        } else {
160            // Simple constant wind (original implementation)
161            let seg = &params.wind_segments[0];
162            let wind_speed_mps = seg.0 * 0.2777778; // km/h to m/s
163            let wind_angle_rad = seg.1.to_radians();
164            // Z IS DOWNRANGE: x=lateral, y=vertical, z=downrange
165            Vector3::new(
166                -wind_speed_mps * wind_angle_rad.sin(), // x (lateral - crosswind component)
167                0.0,                                    // y (vertical)
168                -wind_speed_mps * wind_angle_rad.cos(), // z (downrange - head/tail component)
169            )
170        }
171    } else {
172        Vector3::zeros()
173    };
174
175    // Create a minimal BallisticInputs struct for the derivatives function
176    let inputs = BallisticInputs {
177        bc_value: params.bc,
178        bc_type: params.drag_model,
179        bullet_mass: params.mass_kg / 0.00006479891, // kg to grains
180        muzzle_velocity: vel.norm() * 3.28084,       // m/s to fps
181        bullet_diameter: 0.308,                      // default
182        bullet_length: 1.24,                         // default
183        twist_rate: 10.0,                            // default
184        is_twist_right: params.is_twist_right,
185        enable_advanced_effects: params.enable_spin_drift
186            || params.enable_magnus
187            || params.enable_coriolis,
188        altitude: params.atmos_params.0,
189        temperature: params.atmos_params.1,
190        pressure: params.atmos_params.2,
191        humidity: params.atmos_params.3,
192        tipoff_yaw: 0.0,
193        target_distance: 1000.0, // default
194        muzzle_angle: 0.0,
195        wind_speed: if !params.wind_segments.is_empty() {
196            params.wind_segments[0].0
197        } else {
198            0.0
199        },
200        wind_angle: if !params.wind_segments.is_empty() {
201            params.wind_segments[0].1
202        } else {
203            0.0
204        },
205        latitude: None,
206        shooting_angle: 0.0,
207        azimuth_angle: 0.0,
208        use_powder_sensitivity: false,
209        powder_temp_sensitivity: 0.0,
210        powder_temp: 59.0,
211        tipoff_decay_distance: 0.0,
212        ground_threshold: -1000.0,
213        bc_segments: params.bc_segments.clone(),
214        caliber_inches: 0.308,
215        weight_grains: params.mass_kg / 0.00006479891,
216        use_bc_segments: params.use_bc_segments,
217        bullet_id: None,
218        bc_segments_data: None,
219        use_enhanced_spin_drift: params.enable_spin_drift,
220        use_form_factor: false,
221        manufacturer: None,
222        bullet_model: None,
223        enable_wind_shear: false,
224        wind_shear_model: "none".to_string(),
225        use_cluster_bc: false,
226        bullet_cluster: None,
227
228        // Pass through custom drag table (CDM) from trajectory parameters
229        custom_drag_table: params.custom_drag_table.clone(),
230
231        bc_type_str: None,
232        enable_pitch_damping: false,
233        enable_precession_nutation: false,
234        use_rk4: true,
235        use_adaptive_rk45: false,
236        enable_trajectory_sampling: false,
237        sample_interval: 10.0,
238        sight_height: 0.0,
239        muzzle_height: 0.0,
240        target_height: 0.0,
241    };
242
243    // Call compute_derivatives - returns [f64; 6] directly
244    let deriv_result = compute_derivatives(
245        pos,
246        vel,
247        &inputs,
248        wind_vector,
249        params.atmos_params,
250        params.bc,
251        params.omega_vector,
252        t,
253    );
254
255    Vector6::new(
256        deriv_result[0],
257        deriv_result[1],
258        deriv_result[2],
259        deriv_result[3],
260        deriv_result[4],
261        deriv_result[5],
262    )
263}
264
265/// Main trajectory integration function
266pub fn integrate_trajectory(
267    initial_state: [f64; 6],
268    t_span: (f64, f64),
269    params: TrajectoryParams,
270    method: &str,
271    tolerance: f64,
272    max_step: f64,
273) -> Vec<(f64, Vector6<f64>)> {
274    let mut state = Vector6::new(
275        initial_state[0],
276        initial_state[1],
277        initial_state[2],
278        initial_state[3],
279        initial_state[4],
280        initial_state[5],
281    );
282
283    let mut t = t_span.0;
284    let t_end = t_span.1;
285    let mut dt = (t_end - t) / 1000.0; // Initial step size
286
287    let mut trajectory = Vec::with_capacity(10000);
288    trajectory.push((t, state));
289
290    match method {
291        "RK4" => {
292            // Fixed step RK4 with target detection
293            dt = dt.min(max_step).min(0.001); // Use smaller steps for accuracy
294
295            while t < t_end {
296                if t + dt > t_end {
297                    dt = t_end - t;
298                }
299
300                let new_state = rk4_step(&state, t, dt, &params);
301
302                // Check if we're about to pass the target (z is downrange)
303                if state[2] < params.target_distance_m && new_state[2] >= params.target_distance_m {
304                    // Interpolate to find exact target crossing
305                    let alpha = (params.target_distance_m - state[2]) / (new_state[2] - state[2]);
306                    let dt_to_target = dt * alpha;
307
308                    // Take a smaller step to reach target exactly
309                    let final_state = rk4_step(&state, t, dt_to_target, &params);
310
311                    // Ensure we don't overshoot
312                    let mut corrected_state = final_state;
313                    if corrected_state[2] > params.target_distance_m {
314                        corrected_state[2] = params.target_distance_m;
315                    }
316
317                    trajectory.push((t + dt_to_target, corrected_state));
318                    break; // Stop at target
319                }
320
321                state = new_state;
322                t += dt;
323                trajectory.push((t, state));
324
325                // Check if we've reached or passed the target
326                if state[2] >= params.target_distance_m {
327                    // z is downrange
328                    // Add final point exactly at target
329                    let mut final_state = state;
330                    final_state[2] = params.target_distance_m; // z is downrange
331                    trajectory.push((t, final_state));
332                    break;
333                }
334
335                // Check if bullet hit ground
336                if state[1] < -1000.0 {
337                    break;
338                }
339            }
340        }
341        "RK45" | _ => {
342            // Adaptive RK45 with better sampling
343            let mut last_save_z = 0.0; // z is downrange
344            let save_interval_m = params.target_distance_m / 50.0; // Save ~50 points minimum
345
346            // OPTIMIZATION: Adjust max step size when wind shear is enabled
347            // This improves numerical stability at long ranges
348            let effective_max_step =
349                if params.enable_wind_shear && params.wind_shear_model != "none" {
350                    // Use smaller steps for wind shear, but not TOO small
351                    if params.target_distance_m > 800.0 {
352                        0.01 // Smaller steps for long range with shear (10ms)
353                    } else {
354                        0.02 // Normal steps for medium range with shear (20ms)
355                    }
356                } else {
357                    max_step // Use provided max_step when no wind shear
358                };
359
360            // Set initial step size - ensure it's reasonable
361            dt = dt.min(effective_max_step).max(0.0001); // At least 0.1ms to avoid infinite loops
362
363            // Safety check: maximum iterations to prevent infinite loops
364            let max_iterations = 100000; // Should be more than enough for any realistic trajectory
365            let mut iteration_count = 0;
366
367            while t < t_end && iteration_count < max_iterations {
368                iteration_count += 1;
369
370                // Limit time step for better resolution
371                if t + dt > t_end {
372                    dt = t_end - t;
373                }
374
375                let (new_state, dt_new, _error) = rk45_step(&state, t, dt, &params, tolerance);
376
377                // Check if we're about to pass the target (z is downrange)
378                if state[2] < params.target_distance_m && new_state[2] >= params.target_distance_m {
379                    // Interpolate to find exact target crossing
380                    let alpha = (params.target_distance_m - state[2]) / (new_state[2] - state[2]);
381                    let dt_to_target = dt * alpha;
382
383                    // Take a smaller step to reach target exactly
384                    let (final_state, _, _) =
385                        rk45_step(&state, t, dt_to_target, &params, tolerance);
386
387                    // Make sure we don't overshoot
388                    let mut corrected_state = final_state;
389                    if corrected_state[2] > params.target_distance_m {
390                        corrected_state[2] = params.target_distance_m;
391                    }
392
393                    trajectory.push((t + dt_to_target, corrected_state));
394                    break; // Stop at target - no more points after this
395                }
396
397                // Update state
398                state = new_state;
399                t += dt;
400
401                // Save trajectory point if we've moved enough distance
402                if state[2] - last_save_z >= save_interval_m || state[2] >= params.target_distance_m
403                {
404                    // z is downrange
405                    trajectory.push((t, state));
406                    last_save_z = state[2];
407                }
408
409                // Limit dt for next step - ensure we get enough resolution
410                dt = dt_new.min(effective_max_step).max(0.0001); // Use effective max step, min 0.1ms
411
412                // Stop if we've reached the target
413                if state[2] >= params.target_distance_m {
414                    // z is downrange
415                    // Add final point at target distance
416                    let mut final_state = state;
417                    final_state[2] = params.target_distance_m; // z is downrange
418                    trajectory.push((t, final_state));
419                    break;
420                }
421
422                // Check if bullet hit ground
423                if state[1] < -1000.0 {
424                    break;
425                }
426            }
427
428            // Warn if we hit the iteration limit
429            if iteration_count >= max_iterations {
430                eprintln!(
431                    "WARNING: Trajectory integration hit maximum iteration limit ({} iterations)",
432                    max_iterations
433                );
434                eprintln!("  Final time: {}, Target time: {}", t, t_end);
435                eprintln!(
436                    "  Final position: z={}, Target: {}m",
437                    state[2], params.target_distance_m
438                );
439            }
440        }
441    }
442
443    trajectory
444}
445
446/// Python-exposed function for complete trajectory integration
447pub fn solve_trajectory_rust(
448    initial_state: [f64; 6],
449    t_span: (f64, f64),
450    mass_kg: f64,
451    bc: f64,
452    drag_model: DragModel,
453    wind_segments: Vec<WindSegment>,
454    atmos_params: (f64, f64, f64, f64),
455    omega_vector: Option<Vec<f64>>,
456    enable_spin_drift: bool,
457    enable_magnus: bool,
458    enable_coriolis: bool,
459    method: String,
460    tolerance: f64,
461    max_step: f64,
462    target_distance_m: f64,
463) -> Vec<HashMap<String, f64>> {
464    let omega_vec = omega_vector.map(|v| Vector3::new(v[0], v[1], v[2]));
465
466    let params = TrajectoryParams {
467        mass_kg,
468        bc,
469        drag_model,
470        wind_segments,
471        atmos_params,
472        omega_vector: omega_vec,
473        enable_spin_drift,
474        enable_magnus,
475        enable_coriolis,
476        target_distance_m,
477        enable_wind_shear: false, // Default for test function
478        wind_shear_model: "none".to_string(),
479        shooter_altitude_m: 0.0,
480        is_twist_right: true,    // Default for test function
481        custom_drag_table: None, // No CDM for test function
482        bc_segments: None,       // No BC segments for legacy function
483        use_bc_segments: false,
484    };
485
486    let trajectory =
487        integrate_trajectory(initial_state, t_span, params, &method, tolerance, max_step);
488
489    // Convert to Python-friendly format
490    trajectory
491        .into_iter()
492        .map(|(t, state)| {
493            let mut point = HashMap::new();
494            point.insert("t".to_string(), t);
495            point.insert("x".to_string(), state[0]);
496            point.insert("y".to_string(), state[1]);
497            point.insert("z".to_string(), state[2]);
498            point.insert("vx".to_string(), state[3]);
499            point.insert("vy".to_string(), state[4]);
500            point.insert("vz".to_string(), state[5]);
501            point
502        })
503        .collect()
504}
505
#[cfg(test)]
mod tests {
    use super::*;

    /// Muzzle state `[x, y, z, vx, vy, vz]` (same as the Python test): z = 0 is the downrange
    /// start and vz = 821.52 is the downrange velocity.
    const INITIAL_STATE: [f64; 6] = [0.0, -0.038, 0.0, 0.0, 48.61, 821.52];

    /// 1000 yards in meters.
    const TARGET_M: f64 = 914.4;

    /// 175 gr, G7 BC 0.442 bullet with no wind, shear, spin drift, Magnus or Coriolis.
    fn test_params() -> TrajectoryParams {
        TrajectoryParams {
            mass_kg: 0.01134, // 175 grains in kg
            bc: 0.442,
            drag_model: DragModel::G7,
            wind_segments: vec![(0.0, 90.0, 914.4)],
            atmos_params: (0.0, 59.0, 29.92, 0.0),
            omega_vector: None,
            enable_spin_drift: false,
            enable_magnus: false,
            enable_coriolis: false,
            target_distance_m: TARGET_M,
            enable_wind_shear: false,
            wind_shear_model: "none".to_string(),
            shooter_altitude_m: 0.0,
            is_twist_right: true,
            custom_drag_table: None,
            bc_segments: None,
            use_bc_segments: false,
        }
    }

    /// Integrate with `method` and return `(t, z)` for every recorded point.
    fn run(method: &str) -> Vec<(f64, f64)> {
        integrate_trajectory(INITIAL_STATE, (0.0, 10.0), test_params(), method, 1e-6, 0.01)
            .into_iter()
            .map(|(t, state)| (t, state[2]))
            .collect()
    }

    /// Shared checks:
    /// * the bullet moved downrange;
    /// * time never runs backwards;
    /// * the last point lies on the target plane.
    fn assert_reaches_target(points: &[(f64, f64)]) {
        assert!(
            points.len() > 1,
            "Trajectory should have more than 1 point, but has {}",
            points.len()
        );
        assert!(
            points.windows(2).all(|w| w[1].0 >= w[0].0),
            "Trajectory timestamps must be non-decreasing"
        );

        let &(t_final, z_final) = points.last().expect("trajectory is non-empty");
        assert!(t_final > 0.0, "Final time should be positive, got {}", t_final);
        assert!(
            (z_final - TARGET_M).abs() < 0.05,
            "Final z should lie on the target plane ({} m), got {}",
            TARGET_M,
            z_final
        );
    }

    #[test]
    fn test_integrate_trajectory_basic() {
        assert_reaches_target(&run("RK45"));
    }

    #[test]
    fn test_integrate_trajectory_rk4_reaches_target() {
        assert_reaches_target(&run("RK4"));
    }
}