ballistics_engine/trajectory_solver.rs

1use crate::constants::{FPS_TO_MPS, GRAINS_TO_KG, MPS_TO_FPS};
2use crate::InternalBallisticInputs;
3use nalgebra::Vector3;
4
// Constants for unit conversions, derived from the exact SI definitions of the
// customary units so that they carry full f64 precision.

/// Metres per international yard (exact by definition).
const YARDS_TO_METERS: f64 = 0.9144;
/// Foot-pounds-force per joule: 1 ft·lbf = 0.3048 m × 0.45359237 kg × 9.80665 m/s² (exact).
const JOULES_TO_FTLBS: f64 = 1.0 / (0.3048 * 0.453_592_37 * 9.806_65);
/// Inches per metre: 1 in = 0.0254 m (exact).
const METERS_TO_INCHES: f64 = 1.0 / 0.0254;
9
10/// Initial conditions for trajectory solving
11#[derive(Debug, Clone)]
12pub struct InitialConditions {
13    pub mass_kg: f64,
14    pub muzzle_velocity_mps: f64,
15    pub target_distance_los_m: f64,
16    pub muzzle_angle_rad: f64,
17    pub muzzle_energy_j: f64,
18    pub muzzle_energy_ftlbs: f64,
19    pub target_horizontal_dist_m: f64,
20    pub target_vertical_height_m: f64,
21    pub initial_state: [f64; 6], // [x, y, z, vx, vy, vz]
22    pub t_span: (f64, f64),
23    pub omega_vector: Option<Vector3<f64>>,
24    pub stability_coefficient: f64,
25    pub atmo_params: (f64, f64, f64, f64), // altitude, temp_c, pressure_hpa, density_ratio
26    pub air_density: f64,
27    pub speed_of_sound: f64,
28}
29
/// Summary of a solved trajectory, as produced by `post_process_trajectory`.
///
/// Every quantity is given in SI units, with imperial duplicates where the
/// field name says so (`_in` = inches, `_fps` = feet per second,
/// `_ftlbs` = foot-pounds).
#[derive(Clone, Debug)]
pub struct TrajectoryResult {
    /// Kinetic energy at the muzzle, in joules.
    pub muzzle_energy_j: f64,
    /// Kinetic energy at the muzzle, in foot-pounds.
    pub muzzle_energy_ftlbs: f64,
    /// Line-of-sight distance to the target, in metres.
    pub target_distance_los_m: f64,
    /// Horizontal distance to the target, in metres.
    pub target_distance_horiz_m: f64,
    /// Height of the target relative to the muzzle, in metres.
    pub target_vertical_height_m: f64,
    /// Time from muzzle to the reported end point, in seconds.
    pub time_of_flight_s: f64,
    /// Target height minus projectile height at the end point, in metres.
    pub drop_m: f64,
    /// `drop_m` expressed in inches.
    pub drop_in: f64,
    /// Lateral (z) offset at the end point, in metres.
    pub wind_drift_m: f64,
    /// `wind_drift_m` expressed in inches.
    pub wind_drift_in: f64,
    /// Highest sampled height reached at or before the target, in metres.
    pub max_ord_m: f64,
    /// `max_ord_m` expressed in inches.
    pub max_ord_in: f64,
    /// Horizontal distance at which `max_ord_m` occurs, in metres.
    pub max_ord_dist_horiz_m: f64,
    /// Speed at the end point, in metres per second.
    pub final_vel_mps: f64,
    /// Speed at the end point, in feet per second.
    pub final_vel_fps: f64,
    /// Kinetic energy at the end point, in joules.
    pub final_energy_j: f64,
    /// Kinetic energy at the end point, in foot-pounds.
    pub final_energy_ftlbs: f64,
    /// Air density used for the solution.
    pub air_density_kg_m3: f64,
    /// Speed of sound used for the solution.
    pub speed_of_sound_mps: f64,
    /// Launch elevation (zero angle plus muzzle angle), in radians.
    pub barrel_angle_rad: f64,
}
54
55/// Prepare initial conditions for trajectory solving
56pub fn prepare_initial_conditions(
57    inputs: &InternalBallisticInputs,
58    zero_angle_rad: f64,
59    atmo_params: (f64, f64, f64, f64),
60    air_density: f64,
61    speed_of_sound: f64,
62    stability_coefficient: f64,
63) -> InitialConditions {
64    let mass_kg = inputs.bullet_mass * GRAINS_TO_KG;
65
66    // Adjust muzzle velocity (basic implementation - could be enhanced)
67    let mv_mps = inputs.muzzle_velocity * FPS_TO_MPS;
68
69    // Calculate target coordinates
70    let target_dist_m_los = inputs.target_distance * YARDS_TO_METERS;
71    let target_horizontal_dist_m = target_dist_m_los;
72    let target_vertical_height_m = 0.0; // Simplified for now
73
74    // Calculate muzzle angle
75    let muzzle_angle_rad = zero_angle_rad + inputs.muzzle_angle.to_radians();
76
77    // Calculate energies
78    let muzzle_energy_j = 0.5 * mass_kg * mv_mps * mv_mps;
79    let muzzle_energy_ftlbs = muzzle_energy_j * JOULES_TO_FTLBS;
80
81    // Set up initial velocity vector
82    let initial_vel = Vector3::new(
83        mv_mps * muzzle_angle_rad.cos(),
84        mv_mps * muzzle_angle_rad.sin(),
85        0.0,
86    );
87
88    // Initial state: [x, y, z, vx, vy, vz]
89    let initial_state = [0.0, 0.0, 0.0, initial_vel.x, initial_vel.y, initial_vel.z];
90
91    // Estimate maximum time
92    let initial_vx = initial_vel.x;
93    let max_time = if initial_vx > 1e-6 && target_horizontal_dist_m > 0.0 {
94        let est_min = target_horizontal_dist_m / initial_vx;
95        (est_min * 3.0).max(10.0)
96    } else {
97        10.0
98    };
99    let t_span = (0.0, max_time);
100
101    // Omega vector for Coriolis (simplified - would need more complex calculation)
102    let omega_vector = if inputs.enable_advanced_effects {
103        // Simplified Coriolis vector calculation
104        let latitude_rad = inputs.latitude.unwrap_or(0.0).to_radians();
105        let earth_rotation_rate = 7.2921159e-5; // rad/s
106        Some(Vector3::new(
107            0.0,
108            earth_rotation_rate * latitude_rad.cos(),
109            earth_rotation_rate * latitude_rad.sin(),
110        ))
111    } else {
112        None
113    };
114
115    InitialConditions {
116        mass_kg,
117        muzzle_velocity_mps: mv_mps,
118        target_distance_los_m: target_dist_m_los,
119        muzzle_angle_rad,
120        muzzle_energy_j,
121        muzzle_energy_ftlbs,
122        target_horizontal_dist_m,
123        target_vertical_height_m,
124        initial_state,
125        t_span,
126        omega_vector,
127        stability_coefficient,
128        atmo_params,
129        air_density,
130        speed_of_sound,
131    }
132}
133
/// Find the highest sampled point of the trajectory at or before the target.
///
/// This is a linear scan over the sampled states. It does no root finding
/// and no interpolation between samples, so the result is only as accurate
/// as the sampling density of `trajectory_points`.
///
/// # Arguments
/// * `trajectory_points` - `(time, state)` pairs with state
///   `[x, y, z, vx, vy, vz]`; only `x` (index 0) and `y` (index 1) are read.
/// * `target_horizontal_dist_m` - samples whose `x` exceeds this distance
///   (plus a 1e-6 m slack) are ignored.
///
/// # Returns
/// `(max_ordinate_m, max_ordinate_x_m)`. Both values are `0.0` when no
/// eligible sample lies above `y = 0`, which includes an empty slice. On ties,
/// the earliest sample wins.
pub fn find_trajectory_apex(
    trajectory_points: &[(f64, [f64; 6])],
    target_horizontal_dist_m: f64,
) -> (f64, f64) {
    let x_limit = target_horizontal_dist_m + 1e-6;

    trajectory_points
        .iter()
        .map(|&(_, state)| (state[0], state[1]))
        .filter(|&(x, _)| x <= x_limit)
        // Strict `>` keeps the first sample among equal heights.
        .fold((0.0, 0.0), |best, (x, y)| if y > best.0 { (y, x) } else { best })
}
156
/// Find a root of `f` inside the bracket `[a, b]` using Brent's method.
///
/// This is the classic `zeroin` algorithm (Brent, 1973). It combines
/// bisection, the secant method and inverse quadratic interpolation. An
/// interpolated step is only accepted when it stays inside the current bracket
/// and shrinks it fast enough; otherwise the method bisects. The root
/// therefore stays bracketed, and a continuous `f` converges given enough
/// iterations.
///
/// # Arguments
/// * `f` - function whose root is sought.
/// * `a`, `b` - bracket end points, in either order. `f(a)` and `f(b)` must
///   not share the same strict sign.
/// * `tolerance` - iteration stops when the bracket half-width drops to
///   `tolerance / 2` (plus a machine-epsilon term relative to the estimate),
///   or when `|f(x)| < tolerance`.
/// * `max_iterations` - maximum number of iterations. Each iteration costs one
///   evaluation of `f`, beyond the two initial ones.
///
/// # Errors
/// * `"Root not bracketed"` if `f(a)` and `f(b)` have the same strict sign.
/// * `"Maximum iterations exceeded"` if no convergence occurs within
///   `max_iterations` iterations.
pub fn brent_root_find<F>(
    f: F,
    mut a: f64,
    mut b: f64,
    tolerance: f64,
    max_iterations: usize,
) -> Result<f64, String>
where
    F: Fn(f64) -> f64,
{
    // Comparing signs (instead of testing `fa * fb > 0.0`) cannot be fooled by
    // the product underflowing to zero.
    let same_sign = |x: f64, y: f64| (x > 0.0 && y > 0.0) || (x < 0.0 && y < 0.0);

    let mut fa = f(a);
    let mut fb = f(b);
    if same_sign(fa, fb) {
        return Err("Root not bracketed".to_string());
    }

    // Loop invariants:
    // * `b` is the current best estimate (smallest |f|);
    // * `c` is the contrapoint, so [b, c] brackets the root;
    // * `a` is the previous `b` (equal to `c` right after a bracket reset).
    let mut c = a;
    let mut fc = fa;
    let mut d = b - a; // step taken on the latest iteration
    let mut e = d; // step taken on the iteration before that

    for _ in 0..max_iterations {
        // If the newest estimate landed on the contrapoint's side, restore the
        // bracket using the previous estimate.
        if same_sign(fb, fc) {
            c = a;
            fc = fa;
            d = b - a;
            e = d;
        }

        // Keep `b` as the best estimate by exchanging it with the contrapoint.
        if fc.abs() < fb.abs() {
            a = b;
            b = c;
            c = a;
            fa = fb;
            fb = fc;
            fc = fa;
        }

        let tol = 2.0 * f64::EPSILON * b.abs() + 0.5 * tolerance;
        let m = 0.5 * (c - b); // half-width of the bracket, signed towards c
        if m.abs() <= tol || fb == 0.0 || fb.abs() < tolerance {
            return Ok(b);
        }

        if e.abs() >= tol && fa.abs() > fb.abs() {
            // Try interpolation: secant when only two distinct points are
            // known (a == c exactly, both being copies), inverse quadratic otherwise.
            let s = fb / fa;
            let (mut p, mut q) = if a == c {
                (2.0 * m * s, 1.0 - s)
            } else {
                let q = fa / fc;
                let r = fb / fc;
                (
                    s * (2.0 * m * q * (q - r) - (b - a) * (r - 1.0)),
                    (q - 1.0) * (r - 1.0) * (s - 1.0),
                )
            };
            // Normalise so that p >= 0; the step is then p / q.
            if p > 0.0 {
                q = -q;
            } else {
                p = -p;
            }
            // Accept the interpolated step only if it lands well inside the
            // bracket and is less than half the step before last; else bisect.
            let min1 = 3.0 * m * q - (tol * q).abs();
            let min2 = (e * q).abs();
            if 2.0 * p < min1.min(min2) {
                e = d;
                d = p / q;
            } else {
                d = m;
                e = d;
            }
        } else {
            d = m;
            e = d;
        }

        a = b;
        fa = fb;
        // Always move by at least `tol` so that progress is made near the root.
        b += if d.abs() > tol { d } else { tol.copysign(m) };
        fb = f(b);
    }

    Err("Maximum iterations exceeded".to_string())
}
268
269/// Post-process trajectory solution to create final results
270pub fn post_process_trajectory(
271    trajectory_points: &[(f64, [f64; 6])], // (time, state) pairs
272    initial_conditions: &InitialConditions,
273    inputs: &InternalBallisticInputs,
274    target_hit_time: Option<f64>,
275    ground_hit_time: Option<f64>,
276) -> Result<TrajectoryResult, String> {
277    // Determine final time and state
278    let (final_time, final_state) = if let Some(hit_time) = target_hit_time {
279        // Interpolate state at target hit time
280        let final_state = interpolate_trajectory_state(trajectory_points, hit_time)?;
281        (hit_time, final_state)
282    } else if ground_hit_time.is_some() {
283        return Err("Projectile impacted ground before reaching target".to_string());
284    } else if let Some((time, state)) = trajectory_points.last() {
285        (*time, *state)
286    } else {
287        return Err("No trajectory data available".to_string());
288    };
289
290    // Check if target was reached
291    let distance_err = final_state[0] - initial_conditions.target_horizontal_dist_m;
292    if distance_err.abs() >= 1e-3
293        && final_state[0] < initial_conditions.target_horizontal_dist_m - 1e-3
294    {
295        return Err(format!(
296            "Target horizontal distance ({:.2}m) not reached. Max distance: {:.2}m",
297            initial_conditions.target_horizontal_dist_m, final_state[0]
298        ));
299    }
300
301    // Find trajectory apex
302    let (max_ord_m, max_ord_x_m) = find_trajectory_apex(
303        trajectory_points,
304        initial_conditions.target_horizontal_dist_m,
305    );
306
307    // Calculate final results
308    let final_y_m = final_state[1];
309    let final_z_m = final_state[2];
310    let drop_m = initial_conditions.target_vertical_height_m - final_y_m;
311
312    // Calculate wind drift including spin drift
313    let wind_drift_m = final_z_m;
314    if inputs.enable_advanced_effects {
315        // Add spin drift using existing function
316        // TODO: Re-enable when stability module is available
317        // use crate::stability::compute_spin_drift;
318        // wind_drift_m += compute_spin_drift(
319        //     final_time,
320        //     initial_conditions.stability_coefficient,
321        //     inputs.twist_rate,
322        //     inputs.is_twist_right,
323        // );
324    }
325
326    // Calculate final velocity and energy
327    let final_vel = Vector3::new(final_state[3], final_state[4], final_state[5]);
328    let final_vel_mag = final_vel.norm();
329    let final_energy_j = 0.5 * initial_conditions.mass_kg * final_vel_mag * final_vel_mag;
330
331    Ok(TrajectoryResult {
332        muzzle_energy_j: initial_conditions.muzzle_energy_j,
333        muzzle_energy_ftlbs: initial_conditions.muzzle_energy_ftlbs,
334        target_distance_los_m: initial_conditions.target_distance_los_m,
335        target_distance_horiz_m: initial_conditions.target_horizontal_dist_m,
336        target_vertical_height_m: initial_conditions.target_vertical_height_m,
337        time_of_flight_s: final_time,
338        drop_m,
339        drop_in: drop_m * METERS_TO_INCHES,
340        wind_drift_m,
341        wind_drift_in: wind_drift_m * METERS_TO_INCHES,
342        max_ord_m,
343        max_ord_in: max_ord_m * METERS_TO_INCHES,
344        max_ord_dist_horiz_m: max_ord_x_m,
345        final_vel_mps: final_vel_mag,
346        final_vel_fps: final_vel_mag * MPS_TO_FPS,
347        final_energy_j,
348        final_energy_ftlbs: final_energy_j * JOULES_TO_FTLBS,
349        air_density_kg_m3: initial_conditions.air_density,
350        speed_of_sound_mps: initial_conditions.speed_of_sound,
351        barrel_angle_rad: initial_conditions.muzzle_angle_rad,
352    })
353}
354
/// Linearly interpolate the trajectory state at `target_time`.
///
/// The upper sample is the first point whose time is `>= target_time`, or the
/// last point if there is none. The lower sample is the last point, up to and
/// including the upper one, whose time is `<= target_time`, or the first point
/// if there is none. If both are the same sample, or their times differ by less
/// than `f64::EPSILON`, the lower sample's state is returned unchanged. As a
/// result, times outside the sampled range clamp to the nearest end point.
///
/// # Errors
/// Returns an error if `trajectory_points` is empty.
fn interpolate_trajectory_state(
    trajectory_points: &[(f64, [f64; 6])],
    target_time: f64,
) -> Result<[f64; 6], String> {
    let last_idx = match trajectory_points.len() {
        0 => return Err("No trajectory points available".to_string()),
        len => len - 1,
    };

    let upper_idx = trajectory_points
        .iter()
        .position(|&(time, _)| time >= target_time)
        .unwrap_or(last_idx);
    let lower_idx = trajectory_points[..=upper_idx]
        .iter()
        .rposition(|&(time, _)| time <= target_time)
        .unwrap_or(0);

    let (t_lo, state_lo) = trajectory_points[lower_idx];
    if lower_idx == upper_idx {
        return Ok(state_lo);
    }
    let (t_hi, state_hi) = trajectory_points[upper_idx];

    let dt = t_hi - t_lo;
    if dt.abs() < f64::EPSILON {
        return Ok(state_lo);
    }

    let alpha = (target_time - t_lo) / dt;
    let mut blended = state_lo;
    for (value, &hi) in blended.iter_mut().zip(state_hi.iter()) {
        *value += alpha * (hi - *value);
    }
    Ok(blended)
}
399
#[cfg(test)]
mod tests {
    use super::*;

    /// Five-sample trajectory peaking at t = 2 s; the first three samples are
    /// also used as the interpolation fixture.
    const SAMPLES: [(f64, [f64; 6]); 5] = [
        (0.0, [0.0, 0.0, 0.0, 100.0, 50.0, 0.0]),
        (1.0, [100.0, 45.0, 0.0, 99.0, 40.0, 0.0]),
        (2.0, [200.0, 80.0, 0.0, 98.0, 30.0, 0.0]),
        (3.0, [300.0, 75.0, 0.0, 97.0, 20.0, 0.0]),
        (4.0, [400.0, 60.0, 0.0, 96.0, 10.0, 0.0]),
    ];

    /// Panics unless `actual` lies strictly within `tol` of `expected`.
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        assert!(
            (actual - expected).abs() < tol,
            "expected {} (tolerance {}), got {}",
            expected,
            tol,
            actual
        );
    }

    #[test]
    fn test_brent_root_find() {
        // x^2 - 4 changes sign on [1, 3]; the root there is x = 2.
        let root = brent_root_find(|x: f64| x * x - 4.0, 1.0, 3.0, 1e-6, 100)
            .expect("x^2 - 4 is bracketed by [1, 3]");
        assert_close(root, 2.0, 1e-6);
    }

    #[test]
    fn test_interpolate_trajectory_state() {
        // t = 1.5 lies halfway between the samples at t = 1 and t = 2.
        let state = interpolate_trajectory_state(&SAMPLES[..3], 1.5)
            .expect("fixture is non-empty");
        assert_close(state[0], 150.0, 1e-10); // x position
        assert_close(state[1], 62.5, 1e-10); // y position
        assert_close(state[3], 98.5, 1e-10); // vx velocity
    }

    #[test]
    fn test_find_trajectory_apex() {
        // The target lies beyond every sample, so the t = 2 s peak is eligible.
        let (max_ord, max_ord_x) = find_trajectory_apex(&SAMPLES, 500.0);
        assert_close(max_ord, 80.0, 1e-10);
        assert_close(max_ord_x, 200.0, 1e-10);
    }
}