ballistics_engine/
trajectory_solver.rs

1use crate::InternalBallisticInputs;
2use crate::constants::{FPS_TO_MPS, MPS_TO_FPS, GRAINS_TO_KG};
3use nalgebra::Vector3;
4
// Constants for unit conversions, derived from the exact SI definitions
// (international yard = 0.9144 m, inch = 0.0254 m, foot = 0.3048 m,
// avoirdupois pound = 0.45359237 kg, standard gravity = 9.80665 m/s^2) so that
// every factor is consistent rather than individually rounded.

/// Yards to metres (exact by definition).
const YARDS_TO_METERS: f64 = 0.9144;
/// Joules to foot-pounds force; 1 ft·lbf = 0.3048 m × 0.45359237 kg × 9.80665 m/s².
const JOULES_TO_FTLBS: f64 = 1.0 / (0.3048 * 0.453_592_37 * 9.806_65);
/// Metres to inches; 1 in = 0.0254 m exactly.
const METERS_TO_INCHES: f64 = 1.0 / 0.0254;
9
10/// Initial conditions for trajectory solving
11#[derive(Debug, Clone)]
12pub struct InitialConditions {
13    pub mass_kg: f64,
14    pub muzzle_velocity_mps: f64,
15    pub target_distance_los_m: f64,
16    pub muzzle_angle_rad: f64,
17    pub muzzle_energy_j: f64,
18    pub muzzle_energy_ftlbs: f64,
19    pub target_horizontal_dist_m: f64,
20    pub target_vertical_height_m: f64,
21    pub initial_state: [f64; 6],  // [x, y, z, vx, vy, vz]
22    pub t_span: (f64, f64),
23    pub omega_vector: Option<Vector3<f64>>,
24    pub stability_coefficient: f64,
25    pub atmo_params: (f64, f64, f64, f64),  // altitude, temp_c, pressure_hpa, density_ratio
26    pub air_density: f64,
27    pub speed_of_sound: f64,
28}
29
/// Result of trajectory post-processing.
///
/// Metric fields are paired with imperial counterparts (`_in`, `_fps`, `_ftlbs`)
/// derived from the same value.
#[derive(Debug, Clone, PartialEq)]
pub struct TrajectoryResult {
    /// Muzzle kinetic energy (J).
    pub muzzle_energy_j: f64,
    /// Muzzle kinetic energy (ft·lbf).
    pub muzzle_energy_ftlbs: f64,
    /// Line-of-sight distance to the target (m).
    pub target_distance_los_m: f64,
    /// Horizontal distance to the target (m).
    pub target_distance_horiz_m: f64,
    /// Height of the target relative to the muzzle (m).
    pub target_vertical_height_m: f64,
    /// Time of flight to the terminal point (s).
    pub time_of_flight_s: f64,
    /// Target height minus the projectile's final height (m); positive means below.
    pub drop_m: f64,
    /// `drop_m` in inches.
    pub drop_in: f64,
    /// Crossrange (`z`) displacement at the terminal point (m).
    pub wind_drift_m: f64,
    /// `wind_drift_m` in inches.
    pub wind_drift_in: f64,
    /// Maximum ordinate (highest height reached before the target) (m).
    pub max_ord_m: f64,
    /// `max_ord_m` in inches.
    pub max_ord_in: f64,
    /// Horizontal distance at which the maximum ordinate occurs (m).
    pub max_ord_dist_horiz_m: f64,
    /// Speed at the terminal point (m/s).
    pub final_vel_mps: f64,
    /// Speed at the terminal point (ft/s).
    pub final_vel_fps: f64,
    /// Kinetic energy at the terminal point (J).
    pub final_energy_j: f64,
    /// Kinetic energy at the terminal point (ft·lbf).
    pub final_energy_ftlbs: f64,
    /// Air density used for the solution (kg/m³).
    pub air_density_kg_m3: f64,
    /// Speed of sound used for the solution (m/s).
    pub speed_of_sound_mps: f64,
    /// Launch (barrel) elevation angle (rad).
    pub barrel_angle_rad: f64,
}
54
55/// Prepare initial conditions for trajectory solving
56pub fn prepare_initial_conditions(
57    inputs: &InternalBallisticInputs,
58    zero_angle_rad: f64,
59    atmo_params: (f64, f64, f64, f64),
60    air_density: f64,
61    speed_of_sound: f64,
62    stability_coefficient: f64,
63) -> InitialConditions {
64    let mass_kg = inputs.bullet_mass * GRAINS_TO_KG;
65    
66    // Adjust muzzle velocity (basic implementation - could be enhanced)
67    let mv_mps = inputs.muzzle_velocity * FPS_TO_MPS;
68    
69    // Calculate target coordinates
70    let target_dist_m_los = inputs.target_distance * YARDS_TO_METERS;
71    let target_horizontal_dist_m = target_dist_m_los;
72    let target_vertical_height_m = 0.0; // Simplified for now
73    
74    // Calculate muzzle angle
75    let muzzle_angle_rad = zero_angle_rad + inputs.muzzle_angle.to_radians();
76    
77    // Calculate energies
78    let muzzle_energy_j = 0.5 * mass_kg * mv_mps * mv_mps;
79    let muzzle_energy_ftlbs = muzzle_energy_j * JOULES_TO_FTLBS;
80    
81    // Set up initial velocity vector
82    let initial_vel = Vector3::new(
83        mv_mps * muzzle_angle_rad.cos(),
84        mv_mps * muzzle_angle_rad.sin(),
85        0.0,
86    );
87    
88    // Initial state: [x, y, z, vx, vy, vz]
89    let initial_state = [
90        0.0, 0.0, 0.0,
91        initial_vel.x, initial_vel.y, initial_vel.z,
92    ];
93    
94    // Estimate maximum time
95    let initial_vx = initial_vel.x;
96    let max_time = if initial_vx > 1e-6 && target_horizontal_dist_m > 0.0 {
97        let est_min = target_horizontal_dist_m / initial_vx;
98        (est_min * 3.0).max(10.0)
99    } else {
100        10.0
101    };
102    let t_span = (0.0, max_time);
103    
104    // Omega vector for Coriolis (simplified - would need more complex calculation)
105    let omega_vector = if inputs.enable_advanced_effects {
106        // Simplified Coriolis vector calculation
107        let latitude_rad = inputs.latitude.unwrap_or(0.0).to_radians();
108        let earth_rotation_rate = 7.2921159e-5; // rad/s
109        Some(Vector3::new(
110            0.0,
111            earth_rotation_rate * latitude_rad.cos(),
112            earth_rotation_rate * latitude_rad.sin(),
113        ))
114    } else {
115        None
116    };
117    
118    InitialConditions {
119        mass_kg,
120        muzzle_velocity_mps: mv_mps,
121        target_distance_los_m: target_dist_m_los,
122        muzzle_angle_rad,
123        muzzle_energy_j,
124        muzzle_energy_ftlbs,
125        target_horizontal_dist_m,
126        target_vertical_height_m,
127        initial_state,
128        t_span,
129        omega_vector,
130        stability_coefficient,
131        atmo_params,
132        air_density,
133        speed_of_sound,
134    }
135}
136
/// Locate the trajectory apex (maximum ordinate) among the sampled points.
///
/// This is a linear scan over the samples; no root finding or interpolation is
/// performed. Only points whose `x` is at most `target_horizontal_dist_m + 1e-6`
/// are considered. Returns `(max_ordinate_m, max_ordinate_x_m)`.
///
/// The search starts from `(0.0, 0.0)`, so a trajectory that never rises above
/// `y = 0` (or an empty slice) yields `(0.0, 0.0)`. When several samples share the
/// maximum height, the earliest one wins.
pub fn find_trajectory_apex(
    trajectory_points: &[(f64, [f64; 6])], // (time, state) pairs
    target_horizontal_dist_m: f64,
) -> (f64, f64) {
    let x_limit = target_horizontal_dist_m + 1e-6;

    trajectory_points
        .iter()
        .map(|(_, state)| (state[1], state[0])) // (height, downrange)
        .filter(|&(_, x)| x <= x_limit)
        .fold((0.0, 0.0), |best, (y, x)| if y > best.0 { (y, x) } else { best })
}
158
/// Find a root of `f` in the bracket `[a, b]` using Brent's method.
///
/// Combines inverse quadratic interpolation, the secant method and bisection.
/// `b` always holds the best estimate and `c` a contrapoint with `f(c)` of
/// opposite sign, so the root stays bracketed. Interpolated steps are accepted
/// only when they stay inside the bracket and shrink fast enough; otherwise a
/// bisection step is taken.
///
/// Iteration stops with `Ok(b)` when `|f(b)| < tolerance`, when `f(b) == 0`, or
/// when the bracket half-width is at most `2·ε·|b| + tolerance / 2`. The
/// `tolerance` therefore bounds both the residual and the interval width.
///
/// # Errors
/// * `"Root not bracketed"` if `f(a)` and `f(b)` are both strictly positive or
///   both strictly negative.
/// * `"Maximum iterations exceeded"` if no convergence within `max_iterations`.
pub fn brent_root_find<F>(
    f: F,
    mut a: f64,
    mut b: f64,
    tolerance: f64,
    max_iterations: usize,
) -> Result<f64, String>
where
    F: Fn(f64) -> f64,
{
    // Sign comparison that cannot be fooled by a product underflowing to zero.
    fn same_sign(x: f64, y: f64) -> bool {
        (x > 0.0 && y > 0.0) || (x < 0.0 && y < 0.0)
    }

    let mut fa = f(a);
    let mut fb = f(b);

    if same_sign(fa, fb) {
        return Err("Root not bracketed".to_string());
    }

    // c is the contrapoint: f(b) and f(c) always bracket the root.
    let mut c = a;
    let mut fc = fa;
    let mut d = b - a; // current step
    let mut e = d; // step before last

    for _ in 0..max_iterations {
        // Keep b as the best estimate (smallest |f|).
        if fc.abs() < fb.abs() {
            a = b;
            b = c;
            c = a;
            fa = fb;
            fb = fc;
            fc = fa;
        }

        if fb.abs() < tolerance || fb == 0.0 {
            return Ok(b);
        }

        let tolerance_scaled = 2.0 * f64::EPSILON * b.abs() + 0.5 * tolerance;
        let m = 0.5 * (c - b);

        if m.abs() <= tolerance_scaled {
            return Ok(b);
        }

        // Try interpolation only if the previous step was large enough and
        // actually reduced |f| (a is the previous iterate).
        if e.abs() >= tolerance_scaled && fa.abs() > fb.abs() {
            let s = fb / fa;
            let (mut p, mut q) = if (a - c).abs() < f64::EPSILON {
                // Only two distinct points: secant (linear interpolation).
                (2.0 * m * s, 1.0 - s)
            } else {
                // Inverse quadratic interpolation through (a, fa), (b, fb), (c, fc).
                let q = fa / fc;
                let r = fb / fc;
                (
                    s * (2.0 * m * q * (q - r) - (b - a) * (r - 1.0)),
                    (q - 1.0) * (r - 1.0) * (s - 1.0),
                )
            };

            // Normalise so that p >= 0; the step is then p / q.
            if p > 0.0 {
                q = -q;
            } else {
                p = -p;
            }

            // Accept the step if it lands within 3/4 of the bracket and is less
            // than half the step before last; otherwise bisect.
            let min1 = 3.0 * m * q - (tolerance_scaled * q).abs();
            let min2 = (e * q).abs();
            if 2.0 * p < min1.min(min2) {
                e = d;
                d = p / q;
            } else {
                d = m;
                e = d;
            }
        } else {
            d = m;
            e = d;
        }

        a = b;
        fa = fb;

        // Never step by less than the tolerance, so progress is guaranteed.
        b += if d.abs() > tolerance_scaled {
            d
        } else {
            tolerance_scaled.copysign(m)
        };

        fb = f(b);

        // If b crossed to c's side, the previous iterate becomes the contrapoint.
        if same_sign(fb, fc) {
            c = a;
            fc = fa;
            d = b - a;
            e = d;
        }
    }

    Err("Maximum iterations exceeded".to_string())
}
271
272/// Post-process trajectory solution to create final results
273pub fn post_process_trajectory(
274    trajectory_points: &[(f64, [f64; 6])], // (time, state) pairs
275    initial_conditions: &InitialConditions,
276    inputs: &InternalBallisticInputs,
277    target_hit_time: Option<f64>,
278    ground_hit_time: Option<f64>,
279) -> Result<TrajectoryResult, String> {
280    // Determine final time and state
281    let (final_time, final_state) = if let Some(hit_time) = target_hit_time {
282        // Interpolate state at target hit time
283        let final_state = interpolate_trajectory_state(trajectory_points, hit_time)?;
284        (hit_time, final_state)
285    } else if ground_hit_time.is_some() {
286        return Err("Projectile impacted ground before reaching target".to_string());
287    } else if let Some((time, state)) = trajectory_points.last() {
288        (*time, *state)
289    } else {
290        return Err("No trajectory data available".to_string());
291    };
292    
293    // Check if target was reached
294    let distance_err = final_state[0] - initial_conditions.target_horizontal_dist_m;
295    if distance_err.abs() >= 1e-3 && final_state[0] < initial_conditions.target_horizontal_dist_m - 1e-3 {
296        return Err(format!(
297            "Target horizontal distance ({:.2}m) not reached. Max distance: {:.2}m",
298            initial_conditions.target_horizontal_dist_m, final_state[0]
299        ));
300    }
301    
302    // Find trajectory apex
303    let (max_ord_m, max_ord_x_m) = find_trajectory_apex(
304        trajectory_points,
305        initial_conditions.target_horizontal_dist_m,
306    );
307    
308    // Calculate final results
309    let final_y_m = final_state[1];
310    let final_z_m = final_state[2];
311    let drop_m = initial_conditions.target_vertical_height_m - final_y_m;
312    
313    // Calculate wind drift including spin drift
314    let wind_drift_m = final_z_m;
315    if inputs.enable_advanced_effects {
316        // Add spin drift using existing function
317        // TODO: Re-enable when stability module is available
318        // use crate::stability::compute_spin_drift;
319        // wind_drift_m += compute_spin_drift(
320        //     final_time,
321        //     initial_conditions.stability_coefficient,
322        //     inputs.twist_rate,
323        //     inputs.is_twist_right,
324        // );
325    }
326    
327    // Calculate final velocity and energy
328    let final_vel = Vector3::new(final_state[3], final_state[4], final_state[5]);
329    let final_vel_mag = final_vel.norm();
330    let final_energy_j = 0.5 * initial_conditions.mass_kg * final_vel_mag * final_vel_mag;
331    
332    Ok(TrajectoryResult {
333        muzzle_energy_j: initial_conditions.muzzle_energy_j,
334        muzzle_energy_ftlbs: initial_conditions.muzzle_energy_ftlbs,
335        target_distance_los_m: initial_conditions.target_distance_los_m,
336        target_distance_horiz_m: initial_conditions.target_horizontal_dist_m,
337        target_vertical_height_m: initial_conditions.target_vertical_height_m,
338        time_of_flight_s: final_time,
339        drop_m,
340        drop_in: drop_m * METERS_TO_INCHES,
341        wind_drift_m,
342        wind_drift_in: wind_drift_m * METERS_TO_INCHES,
343        max_ord_m,
344        max_ord_in: max_ord_m * METERS_TO_INCHES,
345        max_ord_dist_horiz_m: max_ord_x_m,
346        final_vel_mps: final_vel_mag,
347        final_vel_fps: final_vel_mag * MPS_TO_FPS,
348        final_energy_j,
349        final_energy_ftlbs: final_energy_j * JOULES_TO_FTLBS,
350        air_density_kg_m3: initial_conditions.air_density,
351        speed_of_sound_mps: initial_conditions.speed_of_sound,
352        barrel_angle_rad: initial_conditions.muzzle_angle_rad,
353    })
354}
355
/// Linearly interpolate the state vector at `target_time`.
///
/// `trajectory_points` is expected to be ordered by increasing time. A time that
/// coincides with a sample returns that sample unchanged. Times before the first
/// sample or after the last are clamped to the first/last state.
///
/// # Errors
/// Returns `Err` if `trajectory_points` is empty.
fn interpolate_trajectory_state(
    trajectory_points: &[(f64, [f64; 6])],
    target_time: f64,
) -> Result<[f64; 6], String> {
    let last_idx = match trajectory_points.len().checked_sub(1) {
        Some(idx) => idx,
        None => return Err("No trajectory points available".to_string()),
    };

    // First sample at or after the target time; fall back to the last sample.
    let upper_idx = trajectory_points
        .iter()
        .position(|&(t, _)| t >= target_time)
        .unwrap_or(last_idx);

    // Last sample at or before the target time, searching no further than upper_idx.
    let lower_idx = trajectory_points[..=upper_idx]
        .iter()
        .rposition(|&(t, _)| t <= target_time)
        .unwrap_or(0);

    let (t1, state1) = trajectory_points[lower_idx];
    if lower_idx == upper_idx {
        return Ok(state1);
    }

    let (t2, state2) = trajectory_points[upper_idx];
    let span = t2 - t1;
    if span.abs() < f64::EPSILON {
        // Duplicate timestamps: avoid dividing by (almost) zero.
        return Ok(state1);
    }

    let alpha = (target_time - t1) / span;
    Ok(std::array::from_fn(|i| state1[i] + alpha * (state2[i] - state1[i])))
}
400
#[cfg(test)]
mod tests {
    use super::*;

    /// Five-sample trajectory that peaks at (x = 200 m, y = 80 m).
    fn arched_trajectory() -> Vec<(f64, [f64; 6])> {
        vec![
            (0.0, [0.0, 0.0, 0.0, 100.0, 50.0, 0.0]),
            (1.0, [100.0, 45.0, 0.0, 99.0, 40.0, 0.0]),
            (2.0, [200.0, 80.0, 0.0, 98.0, 30.0, 0.0]),
            (3.0, [300.0, 75.0, 0.0, 97.0, 20.0, 0.0]),
            (4.0, [400.0, 60.0, 0.0, 96.0, 10.0, 0.0]),
        ]
    }

    #[test]
    fn test_brent_root_find() {
        // x^2 - 4 = 0 has its positive root at x = 2.
        let f = |x: f64| x * x - 4.0;
        let root = brent_root_find(f, 1.0, 3.0, 1e-6, 100).unwrap();
        assert!((root - 2.0).abs() < 1e-6);
    }

    #[test]
    fn test_brent_root_find_reversed_bracket() {
        let root = brent_root_find(|x: f64| x * x - 4.0, 3.0, 1.0, 1e-6, 100).unwrap();
        assert!((root - 2.0).abs() < 1e-6);
    }

    #[test]
    fn test_brent_root_find_not_bracketed() {
        let err = brent_root_find(|x: f64| x * x + 1.0, -1.0, 1.0, 1e-6, 100).unwrap_err();
        assert_eq!(err, "Root not bracketed");
    }

    #[test]
    fn test_interpolate_trajectory_state() {
        let points = arched_trajectory();
        let result = interpolate_trajectory_state(&points, 1.5).unwrap();

        // Halfway between the samples at t = 1.0 and t = 2.0.
        assert!((result[0] - 150.0).abs() < 1e-10); // x position
        assert!((result[1] - 62.5).abs() < 1e-10); // y position
        assert!((result[3] - 98.5).abs() < 1e-10); // vx velocity
    }

    #[test]
    fn test_interpolate_trajectory_state_clamps_and_hits_samples() {
        let points = arched_trajectory();
        assert_eq!(interpolate_trajectory_state(&points, 2.0).unwrap(), points[2].1);
        assert_eq!(interpolate_trajectory_state(&points, -1.0).unwrap(), points[0].1);
        assert_eq!(interpolate_trajectory_state(&points, 10.0).unwrap(), points[4].1);
    }

    #[test]
    fn test_interpolate_trajectory_state_empty() {
        assert!(interpolate_trajectory_state(&[], 0.0).is_err());
    }

    #[test]
    fn test_find_trajectory_apex() {
        let (max_ord, max_ord_x) = find_trajectory_apex(&arched_trajectory(), 500.0);

        assert!((max_ord - 80.0).abs() < 1e-10);
        assert!((max_ord_x - 200.0).abs() < 1e-10);
    }

    #[test]
    fn test_find_trajectory_apex_ignores_points_beyond_target() {
        // The true peak at x = 200 m lies beyond a 150 m target.
        let (max_ord, max_ord_x) = find_trajectory_apex(&arched_trajectory(), 150.0);

        assert!((max_ord - 45.0).abs() < 1e-10);
        assert!((max_ord_x - 100.0).abs() < 1e-10);
    }
}