// ballistics_engine/trajectory_integration.rs

//! Advanced trajectory integration methods (RK4, RK45)
//!
//! This module provides production-grade numerical integration for ballistic trajectories:
//! - RK4: 4th-order Runge-Kutta (fixed step)
//! - RK45: Dormand-Prince 5(4) adaptive method (the same scheme as
//!   `scipy.integrate.solve_ivp(method="RK45")`)
//!
//! MBA-155: Upstreamed from ballistics_rust for shared use
8
9use nalgebra::{Vector3, Vector6};
10use std::collections::HashMap;
11
12use crate::derivatives::compute_derivatives;
13use crate::wind::WindSegment;
14use crate::DragModel;
15use crate::BallisticInputs;
16
17/// RK4 integration step
18fn rk4_step(
19    state: &Vector6<f64>,
20    t: f64,
21    dt: f64,
22    params: &TrajectoryParams,
23) -> Vector6<f64> {
24    // RK4 integration
25    let k1 = compute_derivatives_vec(state, t, params);
26    let k2 = compute_derivatives_vec(&(state + dt * 0.5 * k1), t + dt * 0.5, params);
27    let k3 = compute_derivatives_vec(&(state + dt * 0.5 * k2), t + dt * 0.5, params);
28    let k4 = compute_derivatives_vec(&(state + dt * k3), t + dt, params);
29
30    state + (dt / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4)
31}
32
33/// Adaptive RK45 integration step (Dormand-Prince method)
34fn rk45_step(
35    state: &Vector6<f64>,
36    t: f64,
37    dt: f64,
38    params: &TrajectoryParams,
39    tol: f64,
40) -> (Vector6<f64>, f64, f64) {
41    // Dormand-Prince coefficients (same as scipy.integrate.solve_ivp RK45)
42    const A21: f64 = 1.0 / 5.0;
43    const A31: f64 = 3.0 / 40.0;
44    const A32: f64 = 9.0 / 40.0;
45    const A41: f64 = 44.0 / 45.0;
46    const A42: f64 = -56.0 / 15.0;
47    const A43: f64 = 32.0 / 9.0;
48    const A51: f64 = 19372.0 / 6561.0;
49    const A52: f64 = -25360.0 / 2187.0;
50    const A53: f64 = 64448.0 / 6561.0;
51    const A54: f64 = -212.0 / 729.0;
52    const A61: f64 = 9017.0 / 3168.0;
53    const A62: f64 = -355.0 / 33.0;
54    const A63: f64 = 46732.0 / 5247.0;
55    const A64: f64 = 49.0 / 176.0;
56    const A65: f64 = -5103.0 / 18656.0;
57    const A71: f64 = 35.0 / 384.0;
58    const A73: f64 = 500.0 / 1113.0;
59    const A74: f64 = 125.0 / 192.0;
60    const A75: f64 = -2187.0 / 6784.0;
61    const A76: f64 = 11.0 / 84.0;
62
63    // 5th order coefficients
64    const B1: f64 = 35.0 / 384.0;
65    const B3: f64 = 500.0 / 1113.0;
66    const B4: f64 = 125.0 / 192.0;
67    const B5: f64 = -2187.0 / 6784.0;
68    const B6: f64 = 11.0 / 84.0;
69
70    // 4th order coefficients (for error estimation)
71    const B1_ERR: f64 = 5179.0 / 57600.0;
72    const B3_ERR: f64 = 7571.0 / 16695.0;
73    const B4_ERR: f64 = 393.0 / 640.0;
74    const B5_ERR: f64 = -92097.0 / 339200.0;
75    const B6_ERR: f64 = 187.0 / 2100.0;
76    const B7_ERR: f64 = 1.0 / 40.0;
77
78    // Compute stages
79    let k1 = compute_derivatives_vec(state, t, params);
80    let k2 = compute_derivatives_vec(&(state + dt * A21 * k1), t + dt * 0.2, params);
81    let k3 = compute_derivatives_vec(&(state + dt * (A31 * k1 + A32 * k2)), t + dt * 0.3, params);
82    let k4 = compute_derivatives_vec(&(state + dt * (A41 * k1 + A42 * k2 + A43 * k3)), t + dt * 0.8, params);
83    let k5 = compute_derivatives_vec(&(state + dt * (A51 * k1 + A52 * k2 + A53 * k3 + A54 * k4)), t + dt * 8.0/9.0, params);
84    let k6 = compute_derivatives_vec(&(state + dt * (A61 * k1 + A62 * k2 + A63 * k3 + A64 * k4 + A65 * k5)), t + dt, params);
85    let k7 = compute_derivatives_vec(&(state + dt * (A71 * k1 + A73 * k3 + A74 * k4 + A75 * k5 + A76 * k6)), t + dt, params);
86
87    // 5th order solution
88    let y_new = state + dt * (B1 * k1 + B3 * k3 + B4 * k4 + B5 * k5 + B6 * k6);
89
90    // 4th order solution for error estimate
91    let y_err = state + dt * (B1_ERR * k1 + B3_ERR * k3 + B4_ERR * k4 + B5_ERR * k5 + B6_ERR * k6 + B7_ERR * k7);
92
93    // Error estimate
94    let error = (y_new - y_err).norm() / (1.0 + state.norm());
95
96    // Adaptive step size
97    let safety = 0.9;
98    let dt_new = if error < tol {
99        dt * safety * (tol / error).powf(0.2).min(2.0)
100    } else {
101        dt * safety * (tol / error).powf(0.25).max(0.1)
102    };
103
104    (y_new, dt_new, error)
105}
106
107/// Parameters for trajectory computation
108pub struct TrajectoryParams {
109    pub mass_kg: f64,
110    pub bc: f64,
111    pub drag_model: DragModel,
112    pub wind_segments: Vec<WindSegment>,
113    pub atmos_params: (f64, f64, f64, f64),
114    pub omega_vector: Option<Vector3<f64>>,
115    pub enable_spin_drift: bool,
116    pub enable_magnus: bool,
117    pub enable_coriolis: bool,
118    pub target_distance_m: f64,  // Target horizontal distance in meters
119    pub enable_wind_shear: bool,
120    pub wind_shear_model: String,
121    pub shooter_altitude_m: f64,
122    pub is_twist_right: bool,  // True for right-hand twist, false for left-hand
123    pub custom_drag_table: Option<crate::drag::DragTable>,  // Custom Drag Model (CDM) data
124}
125
126/// Convert state to Vector6 and call compute_derivatives
127fn compute_derivatives_vec(
128    state: &Vector6<f64>,
129    t: f64,
130    params: &TrajectoryParams,
131) -> Vector6<f64> {
132
133    let pos = Vector3::new(state[0], state[1], state[2]);
134    let vel = Vector3::new(state[3], state[4], state[5]);
135
136    // Calculate wind at current position with shear support
137    let wind_vector = if !params.wind_segments.is_empty() {
138        if params.enable_wind_shear && params.wind_shear_model != "none" {
139            crate::wind_shear::get_wind_at_position(
140                &pos,
141                &params.wind_segments,
142                params.enable_wind_shear,
143                &params.wind_shear_model,
144                params.shooter_altitude_m,
145            )
146        } else {
147            // Simple constant wind (original implementation)
148            let seg = &params.wind_segments[0];
149            let wind_speed_mps = seg.0 * 0.2777778; // km/h to m/s
150            let wind_angle_rad = seg.1.to_radians();
151            // X IS DOWNRANGE: x=downrange, y=vertical, z=lateral
152            Vector3::new(
153                -wind_speed_mps * wind_angle_rad.cos(),  // x (downrange - head/tail component)
154                0.0,                                      // y (vertical)
155                -wind_speed_mps * wind_angle_rad.sin(),  // z (lateral - crosswind component)
156            )
157        }
158    } else {
159        Vector3::zeros()
160    };
161
162    // Create a minimal BallisticInputs struct for the derivatives function
163    let inputs = BallisticInputs {
164        bc_value: params.bc,
165        bc_type: params.drag_model.clone(),
166        bullet_mass: params.mass_kg / 0.00006479891, // kg to grains
167        muzzle_velocity: vel.norm() * 3.28084, // m/s to fps
168        bullet_diameter: 0.308, // default
169        bullet_length: 1.24, // default
170        twist_rate: 10.0, // default
171        is_twist_right: params.is_twist_right,
172        enable_advanced_effects: params.enable_spin_drift || params.enable_magnus || params.enable_coriolis,
173        altitude: params.atmos_params.0,
174        temperature: params.atmos_params.1,
175        pressure: params.atmos_params.2,
176        humidity: params.atmos_params.3,
177        tipoff_yaw: 0.0,
178        target_distance: 1000.0, // default
179        muzzle_angle: 0.0,
180        wind_speed: if !params.wind_segments.is_empty() { params.wind_segments[0].0 } else { 0.0 },
181        wind_angle: if !params.wind_segments.is_empty() { params.wind_segments[0].1 } else { 0.0 },
182        latitude: None,
183        shooting_angle: 0.0,
184        azimuth_angle: 0.0,
185        use_powder_sensitivity: false,
186        powder_temp_sensitivity: 0.0,
187        powder_temp: 59.0,
188        tipoff_decay_distance: 0.0,
189        ground_threshold: -1000.0,
190        bc_segments: None,
191        caliber_inches: 0.308,
192        weight_grains: params.mass_kg / 0.00006479891,
193        use_bc_segments: false,
194        bullet_id: None,
195        bc_segments_data: None,
196        use_enhanced_spin_drift: params.enable_spin_drift,
197        use_form_factor: false,
198        manufacturer: None,
199        bullet_model: None,
200        enable_wind_shear: false,
201        wind_shear_model: "none".to_string(),
202        use_cluster_bc: false,
203        bullet_cluster: None,
204
205        // Pass through custom drag table (CDM) from trajectory parameters
206        custom_drag_table: params.custom_drag_table.clone(),
207
208        bc_type_str: None,
209        enable_pitch_damping: false,
210        enable_precession_nutation: false,
211        use_rk4: true,
212        use_adaptive_rk45: false,
213        enable_trajectory_sampling: false,
214        sample_interval: 10.0,
215        sight_height: 0.0,
216        muzzle_height: 0.0,
217        target_height: 0.0,
218    };
219
220    // Call compute_derivatives - returns [f64; 6] directly
221    let deriv_result = compute_derivatives(
222        pos,
223        vel,
224        &inputs,
225        wind_vector,
226        params.atmos_params,
227        params.bc,
228        params.omega_vector,
229        t,
230    );
231
232    Vector6::new(
233        deriv_result[0], deriv_result[1], deriv_result[2],
234        deriv_result[3], deriv_result[4], deriv_result[5],
235    )
236}
237
238/// Main trajectory integration function
239pub fn integrate_trajectory(
240    initial_state: [f64; 6],
241    t_span: (f64, f64),
242    params: TrajectoryParams,
243    method: &str,
244    tolerance: f64,
245    max_step: f64,
246) -> Vec<(f64, Vector6<f64>)> {
247    let mut state = Vector6::new(
248        initial_state[0], initial_state[1], initial_state[2],
249        initial_state[3], initial_state[4], initial_state[5],
250    );
251
252    let mut t = t_span.0;
253    let t_end = t_span.1;
254    let mut dt = (t_end - t) / 1000.0; // Initial step size
255
256    let mut trajectory = Vec::with_capacity(10000);
257    trajectory.push((t, state.clone()));
258
259    match method {
260        "RK4" => {
261            // Fixed step RK4 with target detection
262            dt = dt.min(max_step).min(0.001); // Use smaller steps for accuracy
263
264            while t < t_end {
265                if t + dt > t_end {
266                    dt = t_end - t;
267                }
268
269                let new_state = rk4_step(&state, t, dt, &params);
270
271                // Check if we're about to pass the target (x is downrange)
272                if state[0] < params.target_distance_m && new_state[0] >= params.target_distance_m {
273                    // Interpolate to find exact target crossing
274                    let alpha = (params.target_distance_m - state[0]) / (new_state[0] - state[0]);
275                    let dt_to_target = dt * alpha;
276
277                    // Take a smaller step to reach target exactly
278                    let final_state = rk4_step(&state, t, dt_to_target, &params);
279
280                    // Ensure we don't overshoot
281                    let mut corrected_state = final_state;
282                    if corrected_state[0] > params.target_distance_m {
283                        corrected_state[0] = params.target_distance_m;
284                    }
285
286                    trajectory.push((t + dt_to_target, corrected_state));
287                    break;  // Stop at target
288                }
289
290                state = new_state;
291                t += dt;
292                trajectory.push((t, state.clone()));
293
294                // Check if we've reached or passed the target
295                if state[0] >= params.target_distance_m {  // x is downrange
296                    // Add final point exactly at target
297                    let mut final_state = state;
298                    final_state[0] = params.target_distance_m;  // x is downrange
299                    trajectory.push((t, final_state));
300                    break;
301                }
302
303                // Check if bullet hit ground
304                if state[1] < -1000.0 {
305                    break;
306                }
307            }
308        }
309        "RK45" | _ => {
310            // Adaptive RK45 with better sampling
311            let mut last_save_x = 0.0;
312            let save_interval_m = params.target_distance_m / 50.0; // Save ~50 points minimum
313
314            // OPTIMIZATION: Adjust max step size when wind shear is enabled
315            // This improves numerical stability at long ranges
316            let effective_max_step = if params.enable_wind_shear && params.wind_shear_model != "none" {
317                // Use smaller steps for wind shear, but not TOO small
318                if params.target_distance_m > 800.0 {
319                    0.01  // Smaller steps for long range with shear (10ms)
320                } else {
321                    0.02  // Normal steps for medium range with shear (20ms)
322                }
323            } else {
324                max_step  // Use provided max_step when no wind shear
325            };
326
327            // Set initial step size - ensure it's reasonable
328            dt = dt.min(effective_max_step).max(0.0001);  // At least 0.1ms to avoid infinite loops
329
330            // Safety check: maximum iterations to prevent infinite loops
331            let max_iterations = 100000;  // Should be more than enough for any realistic trajectory
332            let mut iteration_count = 0;
333
334            while t < t_end && iteration_count < max_iterations {
335                iteration_count += 1;
336
337                // Limit time step for better resolution
338                if t + dt > t_end {
339                    dt = t_end - t;
340                }
341
342                let (new_state, dt_new, _error) = rk45_step(&state, t, dt, &params, tolerance);
343
344                // Check if we're about to pass the target (x is downrange)
345                if state[0] < params.target_distance_m && new_state[0] >= params.target_distance_m {
346                    // Interpolate to find exact target crossing
347                    let alpha = (params.target_distance_m - state[0]) / (new_state[0] - state[0]);
348                    let dt_to_target = dt * alpha;
349
350                    // Take a smaller step to reach target exactly
351                    let (final_state, _, _) = rk45_step(&state, t, dt_to_target, &params, tolerance);
352
353                    // Make sure we don't overshoot
354                    let mut corrected_state = final_state;
355                    if corrected_state[0] > params.target_distance_m {
356                        corrected_state[0] = params.target_distance_m;
357                    }
358
359                    trajectory.push((t + dt_to_target, corrected_state));
360                    break;  // Stop at target - no more points after this
361                }
362
363                // Update state
364                state = new_state;
365                t += dt;
366
367                // Save trajectory point if we've moved enough distance
368                if state[0] - last_save_x >= save_interval_m || state[0] >= params.target_distance_m {  // x is downrange
369                    trajectory.push((t, state.clone()));
370                    last_save_x = state[0];
371                }
372
373                // Limit dt for next step - ensure we get enough resolution
374                dt = dt_new.min(effective_max_step).max(0.0001); // Use effective max step, min 0.1ms
375
376                // Stop if we've reached the target
377                if state[0] >= params.target_distance_m {  // x is downrange
378                    // Add final point at target distance
379                    let mut final_state = state;
380                    final_state[0] = params.target_distance_m;  // x is downrange
381                    trajectory.push((t, final_state));
382                    break;
383                }
384
385                // Check if bullet hit ground
386                if state[1] < -1000.0 {
387                    break;
388                }
389            }
390
391            // Warn if we hit the iteration limit
392            if iteration_count >= max_iterations {
393                eprintln!("WARNING: Trajectory integration hit maximum iteration limit ({} iterations)", max_iterations);
394                eprintln!("  Final time: {}, Target time: {}", t, t_end);
395                eprintln!("  Final position: x={}, Target: {}m", state[0], params.target_distance_m);
396            }
397        }
398    }
399
400    trajectory
401}
402
403/// Python-exposed function for complete trajectory integration
404pub fn solve_trajectory_rust(
405    initial_state: [f64; 6],
406    t_span: (f64, f64),
407    mass_kg: f64,
408    bc: f64,
409    drag_model: DragModel,
410    wind_segments: Vec<WindSegment>,
411    atmos_params: (f64, f64, f64, f64),
412    omega_vector: Option<Vec<f64>>,
413    enable_spin_drift: bool,
414    enable_magnus: bool,
415    enable_coriolis: bool,
416    method: String,
417    tolerance: f64,
418    max_step: f64,
419    target_distance_m: f64,
420) -> Vec<HashMap<String, f64>> {
421    let omega_vec = omega_vector.map(|v| Vector3::new(v[0], v[1], v[2]));
422
423    let params = TrajectoryParams {
424        mass_kg,
425        bc,
426        drag_model,
427        wind_segments,
428        atmos_params,
429        omega_vector: omega_vec,
430        enable_spin_drift,
431        enable_magnus,
432        enable_coriolis,
433        target_distance_m,
434        enable_wind_shear: false,  // Default for test function
435        wind_shear_model: "none".to_string(),
436        shooter_altitude_m: 0.0,
437        is_twist_right: true,  // Default for test function
438        custom_drag_table: None,  // No CDM for test function
439    };
440
441    let trajectory = integrate_trajectory(
442        initial_state,
443        t_span,
444        params,
445        &method,
446        tolerance,
447        max_step,
448    );
449
450    // Convert to Python-friendly format
451    trajectory.into_iter().map(|(t, state)| {
452        let mut point = HashMap::new();
453        point.insert("t".to_string(), t);
454        point.insert("x".to_string(), state[0]);
455        point.insert("y".to_string(), state[1]);
456        point.insert("z".to_string(), state[2]);
457        point.insert("vx".to_string(), state[3]);
458        point.insert("vy".to_string(), state[4]);
459        point.insert("vz".to_string(), state[5]);
460        point
461    }).collect()
462}