Skip to main content

ballistics_engine/
trajectory_solver.rs

1use crate::constants::{FPS_TO_MPS, GRAINS_TO_KG, MPS_TO_FPS};
2use crate::InternalBallisticInputs;
3use nalgebra::Vector3;
4
// Unit-conversion constants derived from the exact definitions of the international
// yard (0.9144 m), inch (0.0254 m), foot (0.3048 m) and pound-force
// (0.45359237 kg x 9.80665 m/s^2), so they agree with one another to full f64 precision.

/// Metres per yard.
const YARDS_TO_METERS: f64 = 0.9144;
/// Foot-pounds (ft-lbf) per joule: reciprocal of 0.3048 m x 0.45359237 kg x 9.80665 m/s^2.
const JOULES_TO_FTLBS: f64 = 1.0 / (0.3048 * 0.453_592_37 * 9.806_65);
/// Inches per metre: exact reciprocal of 0.0254 m per inch.
const METERS_TO_INCHES: f64 = 1.0 / 0.0254;
9
10/// Initial conditions for trajectory solving
11#[derive(Debug, Clone)]
12pub struct InitialConditions {
13    pub mass_kg: f64,
14    pub muzzle_velocity_mps: f64,
15    pub target_distance_los_m: f64,
16    pub muzzle_angle_rad: f64,
17    pub muzzle_energy_j: f64,
18    pub muzzle_energy_ftlbs: f64,
19    pub target_horizontal_dist_m: f64,
20    pub target_vertical_height_m: f64,
21    pub initial_state: [f64; 6], // [x, y, z, vx, vy, vz]
22    pub t_span: (f64, f64),
23    pub omega_vector: Option<Vector3<f64>>,
24    pub stability_coefficient: f64,
25    pub atmo_params: (f64, f64, f64, f64), // altitude, temp_c, pressure_hpa, density_ratio
26    pub air_density: f64,
27    pub speed_of_sound: f64,
28}
29
/// Summary of a solved trajectory, produced by `post_process_trajectory`.
///
/// Quantities are given in SI units with imperial companions (`_in`, `_fps`, `_ftlbs`).
#[derive(Clone, Debug)]
pub struct TrajectoryResult {
    /// Muzzle kinetic energy in joules.
    pub muzzle_energy_j: f64,
    /// Muzzle kinetic energy in foot-pounds.
    pub muzzle_energy_ftlbs: f64,
    /// Line-of-sight distance to the target in metres.
    pub target_distance_los_m: f64,
    /// Horizontal distance to the target in metres.
    pub target_distance_horiz_m: f64,
    /// Target height relative to the muzzle in metres.
    pub target_vertical_height_m: f64,
    /// Time of flight to the final state in seconds.
    pub time_of_flight_s: f64,
    /// Target height minus final projectile height, in metres (positive = below target).
    pub drop_m: f64,
    /// `drop_m` expressed in inches.
    pub drop_in: f64,
    /// Lateral (z) deviation at the final state in metres.
    pub wind_drift_m: f64,
    /// `wind_drift_m` expressed in inches.
    pub wind_drift_in: f64,
    /// Highest sampled height reached before the target, in metres.
    pub max_ord_m: f64,
    /// `max_ord_m` expressed in inches.
    pub max_ord_in: f64,
    /// Horizontal distance at which `max_ord_m` occurs, in metres.
    pub max_ord_dist_horiz_m: f64,
    /// Final speed in metres per second.
    pub final_vel_mps: f64,
    /// Final speed in feet per second.
    pub final_vel_fps: f64,
    /// Final kinetic energy in joules.
    pub final_energy_j: f64,
    /// Final kinetic energy in foot-pounds.
    pub final_energy_ftlbs: f64,
    /// Air density used for the solve, kg/m^3.
    pub air_density_kg_m3: f64,
    /// Speed of sound used for the solve, m/s.
    pub speed_of_sound_mps: f64,
    /// Launch angle used for the solve, in radians.
    pub barrel_angle_rad: f64,
}
54
55/// Prepare initial conditions for trajectory solving
56pub fn prepare_initial_conditions(
57    inputs: &InternalBallisticInputs,
58    zero_angle_rad: f64,
59    atmo_params: (f64, f64, f64, f64),
60    air_density: f64,
61    speed_of_sound: f64,
62    stability_coefficient: f64,
63) -> InitialConditions {
64    let mass_kg = inputs.bullet_mass * GRAINS_TO_KG;
65
66    // Adjust muzzle velocity (basic implementation - could be enhanced)
67    let mv_mps = inputs.muzzle_velocity * FPS_TO_MPS;
68
69    // Calculate target coordinates
70    let target_dist_m_los = inputs.target_distance * YARDS_TO_METERS;
71    let target_horizontal_dist_m = target_dist_m_los;
72    let target_vertical_height_m = 0.0; // Simplified for now
73
74    // Calculate muzzle angle
75    let muzzle_angle_rad = zero_angle_rad + inputs.muzzle_angle.to_radians();
76
77    // Calculate energies
78    let muzzle_energy_j = 0.5 * mass_kg * mv_mps * mv_mps;
79    let muzzle_energy_ftlbs = muzzle_energy_j * JOULES_TO_FTLBS;
80
81    // Set up initial velocity vector
82    let initial_vel = Vector3::new(
83        mv_mps * muzzle_angle_rad.cos(),
84        mv_mps * muzzle_angle_rad.sin(),
85        0.0,
86    );
87
88    // Initial state: [x, y, z, vx, vy, vz]
89    let initial_state = [0.0, 0.0, 0.0, initial_vel.x, initial_vel.y, initial_vel.z];
90
91    // Estimate maximum time
92    let initial_vx = initial_vel.x;
93    let max_time = if initial_vx > 1e-6 && target_horizontal_dist_m > 0.0 {
94        let est_min = target_horizontal_dist_m / initial_vx;
95        (est_min * 3.0).max(10.0)
96    } else {
97        10.0
98    };
99    let t_span = (0.0, max_time);
100
101    // Omega vector for Coriolis accounting for shot azimuth
102    let omega_vector = if inputs.enable_advanced_effects {
103        // Project Earth's rotation vector into the shooter's local frame.
104        // azimuth_angle: 0 = North, pi/2 = East
105        let latitude_rad = inputs.latitude.unwrap_or(0.0).to_radians();
106        let azimuth = inputs.azimuth_angle; // already in radians
107        let earth_rotation_rate = 7.2921159e-5; // rad/s
108        Some(Vector3::new(
109            earth_rotation_rate * latitude_rad.cos() * azimuth.sin(),
110            earth_rotation_rate * latitude_rad.sin(),
111            earth_rotation_rate * latitude_rad.cos() * azimuth.cos(),
112        ))
113    } else {
114        None
115    };
116
117    InitialConditions {
118        mass_kg,
119        muzzle_velocity_mps: mv_mps,
120        target_distance_los_m: target_dist_m_los,
121        muzzle_angle_rad,
122        muzzle_energy_j,
123        muzzle_energy_ftlbs,
124        target_horizontal_dist_m,
125        target_vertical_height_m,
126        initial_state,
127        t_span,
128        omega_vector,
129        stability_coefficient,
130        atmo_params,
131        air_density,
132        speed_of_sound,
133    }
134}
135
/// Locate the highest sampled point of the trajectory at or before the target.
///
/// This is a linear scan over the sampled states; no root finding or interpolation is
/// done, so the result is only as precise as the sampling.
///
/// # Arguments
/// * `trajectory_points` - `(time, state)` pairs; `state[0]` is horizontal distance and
///   `state[1]` is height.
/// * `target_horizontal_dist_m` - samples whose `x` exceeds this (plus 1e-6 m slack)
///   are ignored.
///
/// # Returns
/// `(max_ordinate_m, max_ordinate_x_m)`. Only heights strictly above the running maximum
/// (which starts at 0.0) replace it, so a trajectory that never rises above `y = 0`
/// yields `(0.0, 0.0)`, and on ties the earliest sample wins.
pub fn find_trajectory_apex(
    trajectory_points: &[(f64, [f64; 6])], // (time, state) pairs
    target_horizontal_dist_m: f64,
) -> (f64, f64) {
    let x_limit = target_horizontal_dist_m + 1e-6;

    trajectory_points
        .iter()
        .map(|(_, state)| (state[0], state[1]))
        .filter(|&(x, _)| x <= x_limit)
        .fold((0.0, 0.0), |(best_y, best_x), (x, y)| {
            if y > best_y {
                (y, x)
            } else {
                (best_y, best_x)
            }
        })
}
158
/// Find a root of `f` in the bracket `[a, b]` using Brent's method.
///
/// Each iteration keeps a bracket `[b, c]` where `f(b)` and `f(c)` have opposite signs,
/// with `b` the endpoint of smaller residual and `a` the previous iterate. It tries a
/// secant or inverse-quadratic-interpolation step and falls back to bisection whenever
/// that step would leave the bracket or fail to shrink it quickly enough. This follows
/// the classic `zbrent` formulation.
///
/// # Arguments
/// * `f` - function whose root is sought.
/// * `a`, `b` - bracket endpoints; `f(a)` and `f(b)` must not have the same sign.
/// * `tolerance` - iteration stops once `|f(b)| < tolerance`, or once the bracket
///   half-width is at most `2·EPSILON·|b| + tolerance / 2`.
/// * `max_iterations` - maximum number of iterations (one evaluation of `f` each).
///
/// # Errors
/// * `"Root not bracketed"` if `f(a) * f(b) > 0`.
/// * `"Maximum iterations exceeded"` if no convergence within `max_iterations`.
pub fn brent_root_find<F>(
    f: F,
    mut a: f64,
    mut b: f64,
    tolerance: f64,
    max_iterations: usize,
) -> Result<f64, String>
where
    F: Fn(f64) -> f64,
{
    let mut fa = f(a);
    let mut fb = f(b);

    if fa * fb > 0.0 {
        return Err("Root not bracketed".to_string());
    }

    // Invariant: the root lies between `b` and the contrapoint `c`.
    let mut c = a;
    let mut fc = fa;
    // `d` is the latest step and `e` the one before it; they decide whether an
    // interpolation step is converging fast enough to be trusted.
    let mut d = b - a;
    let mut e = d;

    for _ in 0..max_iterations {
        // Make `b` the endpoint with the smaller residual (`a` becomes equal to `c`).
        if fc.abs() < fb.abs() {
            a = b;
            b = c;
            c = a;
            fa = fb;
            fb = fc;
            fc = fa;
        }

        if fb == 0.0 || fb.abs() < tolerance {
            return Ok(b);
        }

        let tol = 2.0 * f64::EPSILON * b.abs() + 0.5 * tolerance;
        let m = 0.5 * (c - b);
        if m.abs() <= tol {
            return Ok(b);
        }

        if e.abs() >= tol && fa.abs() > fb.abs() {
            // Interpolate through the previous iterate `a`, the best point `b`,
            // and (when distinct from `a`) the contrapoint `c`.
            let s = fb / fa;
            let (mut p, mut q) = if a == c {
                // Only two distinct points: secant step.
                (2.0 * m * s, 1.0 - s)
            } else {
                // Inverse quadratic interpolation.
                let q = fa / fc;
                let r = fb / fc;
                (
                    s * (2.0 * m * q * (q - r) - (b - a) * (r - 1.0)),
                    (q - 1.0) * (r - 1.0) * (s - 1.0),
                )
            };
            // Normalise so that p >= 0; the proposed step is p / q.
            if p > 0.0 {
                q = -q;
            } else {
                p = -p;
            }

            let e_prev = e;
            e = d;
            // Accept only if the step stays within 3/4 of the way to `c` and is smaller than
            // half the step before last; otherwise bisect.
            if 2.0 * p < 3.0 * m * q - (tol * q).abs() && 2.0 * p < (e_prev * q).abs() {
                d = p / q;
            } else {
                d = m;
                e = d;
            }
        } else {
            // Bisection.
            d = m;
            e = d;
        }

        a = b;
        fa = fb;
        // Always move by at least `tol` toward `c` so the bracket keeps shrinking.
        if d.abs() > tol {
            b += d;
        } else if m > 0.0 {
            b += tol;
        } else {
            b -= tol;
        }
        fb = f(b);

        // If the new point is on the same side as `c`, the old `b` (now `a`) is the contrapoint.
        if fb * fc > 0.0 {
            c = a;
            fc = fa;
            d = b - a;
            e = d;
        }
    }

    Err("Maximum iterations exceeded".to_string())
}
270
271/// Post-process trajectory solution to create final results
272pub fn post_process_trajectory(
273    trajectory_points: &[(f64, [f64; 6])], // (time, state) pairs
274    initial_conditions: &InitialConditions,
275    inputs: &InternalBallisticInputs,
276    target_hit_time: Option<f64>,
277    ground_hit_time: Option<f64>,
278) -> Result<TrajectoryResult, String> {
279    // Determine final time and state
280    let (final_time, final_state) = if let Some(hit_time) = target_hit_time {
281        // Interpolate state at target hit time
282        let final_state = interpolate_trajectory_state(trajectory_points, hit_time)?;
283        (hit_time, final_state)
284    } else if ground_hit_time.is_some() {
285        return Err("Projectile impacted ground before reaching target".to_string());
286    } else if let Some((time, state)) = trajectory_points.last() {
287        (*time, *state)
288    } else {
289        return Err("No trajectory data available".to_string());
290    };
291
292    // Check if target was reached
293    let distance_err = final_state[0] - initial_conditions.target_horizontal_dist_m;
294    if distance_err.abs() >= 1e-3
295        && final_state[0] < initial_conditions.target_horizontal_dist_m - 1e-3
296    {
297        return Err(format!(
298            "Target horizontal distance ({:.2}m) not reached. Max distance: {:.2}m",
299            initial_conditions.target_horizontal_dist_m, final_state[0]
300        ));
301    }
302
303    // Find trajectory apex
304    let (max_ord_m, max_ord_x_m) = find_trajectory_apex(
305        trajectory_points,
306        initial_conditions.target_horizontal_dist_m,
307    );
308
309    // Calculate final results
310    let final_y_m = final_state[1];
311    let final_z_m = final_state[2];
312    let drop_m = initial_conditions.target_vertical_height_m - final_y_m;
313
314    // Calculate wind drift including spin drift
315    let wind_drift_m = final_z_m;
316    if inputs.enable_advanced_effects {
317        // Add spin drift using existing function
318        // TODO: Re-enable when stability module is available
319        // use crate::stability::compute_spin_drift;
320        // wind_drift_m += compute_spin_drift(
321        //     final_time,
322        //     initial_conditions.stability_coefficient,
323        //     inputs.twist_rate,
324        //     inputs.is_twist_right,
325        // );
326    }
327
328    // Calculate final velocity and energy
329    let final_vel = Vector3::new(final_state[3], final_state[4], final_state[5]);
330    let final_vel_mag = final_vel.norm();
331    let final_energy_j = 0.5 * initial_conditions.mass_kg * final_vel_mag * final_vel_mag;
332
333    Ok(TrajectoryResult {
334        muzzle_energy_j: initial_conditions.muzzle_energy_j,
335        muzzle_energy_ftlbs: initial_conditions.muzzle_energy_ftlbs,
336        target_distance_los_m: initial_conditions.target_distance_los_m,
337        target_distance_horiz_m: initial_conditions.target_horizontal_dist_m,
338        target_vertical_height_m: initial_conditions.target_vertical_height_m,
339        time_of_flight_s: final_time,
340        drop_m,
341        drop_in: drop_m * METERS_TO_INCHES,
342        wind_drift_m,
343        wind_drift_in: wind_drift_m * METERS_TO_INCHES,
344        max_ord_m,
345        max_ord_in: max_ord_m * METERS_TO_INCHES,
346        max_ord_dist_horiz_m: max_ord_x_m,
347        final_vel_mps: final_vel_mag,
348        final_vel_fps: final_vel_mag * MPS_TO_FPS,
349        final_energy_j,
350        final_energy_ftlbs: final_energy_j * JOULES_TO_FTLBS,
351        air_density_kg_m3: initial_conditions.air_density,
352        speed_of_sound_mps: initial_conditions.speed_of_sound,
353        barrel_angle_rad: initial_conditions.muzzle_angle_rad,
354    })
355}
356
/// Linearly interpolate the trajectory state at `target_time`.
///
/// The upper sample is the first one whose time is `>= target_time` (the last sample if
/// none is). The lower sample is the last one, up to and including the upper sample,
/// whose time is `<= target_time` (the first sample if none is). Times outside the sampled
/// range therefore clamp to the nearest end sample, and an exact time match returns that
/// sample unchanged.
///
/// NOTE(review): the bracketing assumes samples are ordered by time; confirm callers
/// always pass time-sorted points.
///
/// # Errors
/// Returns an error if `trajectory_points` is empty.
fn interpolate_trajectory_state(
    trajectory_points: &[(f64, [f64; 6])],
    target_time: f64,
) -> Result<[f64; 6], String> {
    if trajectory_points.is_empty() {
        return Err("No trajectory points available".to_string());
    }

    let upper_idx = trajectory_points
        .iter()
        .position(|&(t, _)| t >= target_time)
        .unwrap_or(trajectory_points.len() - 1);
    let lower_idx = trajectory_points[..=upper_idx]
        .iter()
        .rposition(|&(t, _)| t <= target_time)
        .unwrap_or(0);

    let (t_lo, lo) = trajectory_points[lower_idx];
    if lower_idx == upper_idx {
        return Ok(lo);
    }
    let (t_hi, hi) = trajectory_points[upper_idx];

    // Coincident sample times: avoid dividing by ~zero.
    let span = t_hi - t_lo;
    if span.abs() < f64::EPSILON {
        return Ok(lo);
    }

    let alpha = (target_time - t_lo) / span;
    Ok(std::array::from_fn(|i| lo[i] + alpha * (hi[i] - lo[i])))
}
401
#[cfg(test)]
mod tests {
    use super::*;

    /// Panics unless `actual` lies strictly within `tol` of `expected`.
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        assert!(
            (actual - expected).abs() < tol,
            "expected {} (±{}), got {}",
            expected,
            tol,
            actual
        );
    }

    #[test]
    fn test_brent_root_find() {
        // x^2 - 4 changes sign on [1, 3]; its root there is x = 2.
        let root = brent_root_find(|x: f64| x * x - 4.0, 1.0, 3.0, 1e-6, 100)
            .expect("x^2 - 4 is bracketed by [1, 3]");
        assert_close(root, 2.0, 1e-6);
    }

    #[test]
    fn test_interpolate_trajectory_state() {
        let samples = [
            (0.0, [0.0, 0.0, 0.0, 100.0, 50.0, 0.0]),
            (1.0, [100.0, 45.0, 0.0, 99.0, 40.0, 0.0]),
            (2.0, [200.0, 80.0, 0.0, 98.0, 30.0, 0.0]),
        ];

        // t = 1.5 lies halfway between the samples at t = 1 and t = 2.
        let state = interpolate_trajectory_state(&samples, 1.5).expect("samples are non-empty");
        assert_close(state[0], 150.0, 1e-10); // x position
        assert_close(state[1], 62.5, 1e-10); // y position
        assert_close(state[3], 98.5, 1e-10); // x velocity
    }

    #[test]
    fn test_find_trajectory_apex() {
        let samples = [
            (0.0, [0.0, 0.0, 0.0, 100.0, 50.0, 0.0]),
            (1.0, [100.0, 45.0, 0.0, 99.0, 40.0, 0.0]),
            (2.0, [200.0, 80.0, 0.0, 98.0, 30.0, 0.0]), // highest sample
            (3.0, [300.0, 75.0, 0.0, 97.0, 20.0, 0.0]),
            (4.0, [400.0, 60.0, 0.0, 96.0, 10.0, 0.0]),
        ];

        let (height, downrange) = find_trajectory_apex(&samples, 500.0);
        assert_close(height, 80.0, 1e-10);
        assert_close(downrange, 200.0, 1e-10);
    }
}