ballistics_engine/
fast_trajectory.rs

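//! Fast trajectory solver for longer ranges.
//!
//! This is a Rust implementation of the fast fixed-step (RK4) trajectory solver
//! that provides significant performance improvements for long-range calculations.
//! A variant that accepts explicit wind segments and delegates to the adaptive RK45
//! integrator is also provided.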
5
6use crate::{
7    atmosphere::get_local_atmosphere,
8    constants::{GRAINS_TO_KG, G_ACCEL_MPS2, MPS_TO_FPS},
9    drag::get_drag_coefficient,
10    wind::WindSock,
11    BCSegmentData, DragModel, InternalBallisticInputs as BallisticInputs,
12};
13use nalgebra::Vector3;
14
/// Fast solution container matching the Python implementation.
///
/// Holds the sampled times, the state history, detected event times and a success
/// flag; [`FastSolution::sol`] linearly interpolates between the stored samples.
#[derive(Debug, Clone)]
pub struct FastSolution {
    /// Sample times, in ascending order (assumed by [`FastSolution::sol`]).
    pub t: Vec<f64>,
    /// State history in component-major layout: `y[j][i]` is state component `j`
    /// at time `t[i]` (6 rows, one column per sample).
    pub y: Vec<Vec<f64>>,
    /// Event times `[target_hit, max_ord, ground_hit]`.
    pub t_events: [Vec<f64>; 3],
    /// Whether integration succeeded.
    pub success: bool,
}

impl FastSolution {
    /// Linearly interpolate the state at each time in `t_query`.
    ///
    /// Returns a 6 × `t_query.len()` matrix in the same component-major layout as
    /// [`FastSolution::y`]. Queries before the first sample or after the last one
    /// are clamped to the end samples, and a NaN query yields the first sample.
    /// An empty solution yields NaN for every entry instead of panicking.
    ///
    /// Requires `t` to be sorted ascending. Repeated timestamps are allowed and
    /// never produce NaN.
    pub fn sol(&self, t_query: &[f64]) -> Vec<Vec<f64>> {
        let n_samples = self.t.len();
        if n_samples == 0 {
            // Nothing to interpolate from: NaN marks "no data".
            return vec![vec![f64::NAN; t_query.len()]; 6];
        }

        let mut result = vec![vec![0.0; t_query.len()]; 6];

        for (i, &tq) in t_query.iter().enumerate() {
            // Lower bound: index of the first sample with t >= tq. Unlike an arbitrary
            // `binary_search` hit, this guarantees t[idx - 1] < tq <= t[idx], so a run of
            // equal timestamps can never produce a zero-width (0/0 = NaN) interval.
            // A NaN query compares false everywhere and therefore maps to idx = 0.
            let idx = self.t.partition_point(|&t| t < tq);

            for (out_row, y_row) in result.iter_mut().zip(&self.y) {
                out_row[i] = if idx == 0 {
                    // At or before the first sample.
                    y_row[0]
                } else if idx >= n_samples {
                    // After the last sample.
                    y_row[n_samples - 1]
                } else {
                    let (t0, t1) = (self.t[idx - 1], self.t[idx]);
                    let (y0, y1) = (y_row[idx - 1], y_row[idx]);
                    if t1 > t0 {
                        let frac = (tq - t0) / (t1 - t0);
                        y0 + frac * (y1 - y0)
                    } else {
                        // Only reachable when `t` is unsorted or contains NaN.
                        y1
                    }
                };
            }
        }

        result
    }

    /// Build a solution from per-sample state rows, where `states[i]` belongs to
    /// `times[i]`.
    ///
    /// The rows are transposed into the component-major layout of
    /// [`FastSolution::y`]. The result always reports `success: true`.
    ///
    /// If `states` is shorter than `times`, the missing columns stay `0.0`.
    ///
    /// # Panics
    /// Panics if `states` has more entries than `times`.
    pub fn from_trajectory_data(
        times: Vec<f64>,
        states: Vec<[f64; 6]>,
        t_events: [Vec<f64>; 3],
    ) -> Self {
        let n_points = times.len();
        let mut y = vec![vec![0.0; n_points]; 6];

        for (i, state) in states.iter().enumerate() {
            for (row, &value) in y.iter_mut().zip(state.iter()) {
                row[i] = value;
            }
        }

        FastSolution {
            t: times,
            y,
            t_events,
            success: true,
        }
    }
}
94
/// Integration parameters shared by [`fast_integrate`] and
/// [`fast_integrate_with_segments`].
///
/// All fields are plain values, so the struct is cheap to copy.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct FastIntegrationParams {
    /// Downrange distance to the target in metres. A target hit is reported once
    /// the downrange coordinate (`z`, state index 2) reaches this value.
    pub horiz: f64,
    /// Vertical target offset.
    /// NOTE(review): not read by the solvers in this module; units presumably
    /// metres — confirm with callers.
    pub vert: f64,
    /// Initial state `[x, y, z, vx, vy, vz]`, with `x` lateral, `y` vertical and
    /// `z` downrange.
    pub initial_state: [f64; 6],
    /// Time window `(start, end)` in seconds. [`fast_integrate`] always starts at
    /// `t = 0` and uses only `end` as an upper bound. The RK45 path forwards the
    /// whole window to the integrator.
    pub t_span: (f64, f64),
    /// Atmospheric parameters. They are forwarded to `get_local_atmosphere` (after
    /// its first argument) and to `TrajectoryParams::atmos_params`.
    /// NOTE(review): the meaning of each component is defined by the atmosphere
    /// module — confirm there.
    pub atmo_params: (f64, f64, f64, f64),
}
103
104/// Fast fixed-step integration for longer trajectories
105pub fn fast_integrate(
106    inputs: &BallisticInputs,
107    wind_sock: &WindSock,
108    params: FastIntegrationParams,
109) -> FastSolution {
110    // Extract parameters
111    let _mass_kg = inputs.bullet_mass * GRAINS_TO_KG;
112    let bc = inputs.bc_value;
113    let drag_model = &inputs.bc_type;
114
115    // Check for BC segments
116    let has_bc_segments =
117        inputs.bc_segments.is_some() && !inputs.bc_segments.as_ref().unwrap().is_empty();
118    let has_bc_segments_data =
119        inputs.bc_segments_data.is_some() && !inputs.bc_segments_data.as_ref().unwrap().is_empty();
120
121    // Time step - adjust based on distance
122    let dt = if params.horiz > 200.0 {
123        0.001
124    } else if params.horiz > 100.0 {
125        0.0005
126    } else {
127        0.0001
128    };
129
130    // Maximum time based on estimated flight time
131    let v0 = Vector3::new(
132        params.initial_state[3],
133        params.initial_state[4],
134        params.initial_state[5],
135    )
136    .norm();
137
138    let t_max = if v0 > 1e-6 && params.horiz > 0.0 {
139        (2.0 * params.horiz / v0).min(params.t_span.1)
140    } else {
141        params.t_span.1
142    };
143
144    // Initialize arrays
145    let n_steps = ((t_max / dt) as usize) + 1;
146    let mut times = Vec::with_capacity(n_steps);
147    let mut states = Vec::with_capacity(n_steps);
148
149    // Initial state
150    times.push(0.0);
151    states.push(params.initial_state);
152
153    // Get base atmospheric density
154    let (base_density, _) = get_local_atmosphere(
155        0.0,
156        params.atmo_params.0,
157        params.atmo_params.1,
158        params.atmo_params.2,
159        params.atmo_params.3,
160    );
161
162    // Integration loop
163    let mut hit_target = false;
164    let mut hit_ground = false;
165    let mut max_ord_time = None;
166    let mut max_ord_y = 0.0;
167    let ground_threshold = inputs.ground_threshold;
168
169    // RK4 integration
170    for i in 0..n_steps - 1 {
171        let t = i as f64 * dt;
172        let state = states[i];
173
174        let pos = Vector3::new(state[0], state[1], state[2]);
175        let _vel = Vector3::new(state[3], state[4], state[5]);
176
177        // Check termination conditions (z is downrange)
178        if pos.z >= params.horiz {
179            hit_target = true;
180            times.push(t);
181            states.push(state);
182            break;
183        }
184
185        if pos.y <= ground_threshold {
186            hit_ground = true;
187            times.push(t);
188            states.push(state);
189            break;
190        }
191
192        // Track maximum ordinate
193        if pos.y > max_ord_y {
194            max_ord_y = pos.y;
195            max_ord_time = Some(t);
196        }
197
198        // RK4 step
199        let k1 = compute_derivatives(
200            &state,
201            inputs,
202            wind_sock,
203            base_density,
204            drag_model,
205            bc,
206            has_bc_segments,
207            has_bc_segments_data,
208        );
209
210        let mut state2 = state;
211        for j in 0..6 {
212            state2[j] = state[j] + 0.5 * dt * k1[j];
213        }
214        let k2 = compute_derivatives(
215            &state2,
216            inputs,
217            wind_sock,
218            base_density,
219            drag_model,
220            bc,
221            has_bc_segments,
222            has_bc_segments_data,
223        );
224
225        let mut state3 = state;
226        for j in 0..6 {
227            state3[j] = state[j] + 0.5 * dt * k2[j];
228        }
229        let k3 = compute_derivatives(
230            &state3,
231            inputs,
232            wind_sock,
233            base_density,
234            drag_model,
235            bc,
236            has_bc_segments,
237            has_bc_segments_data,
238        );
239
240        let mut state4 = state;
241        for j in 0..6 {
242            state4[j] = state[j] + dt * k3[j];
243        }
244        let k4 = compute_derivatives(
245            &state4,
246            inputs,
247            wind_sock,
248            base_density,
249            drag_model,
250            bc,
251            has_bc_segments,
252            has_bc_segments_data,
253        );
254
255        // Update state
256        let mut new_state = state;
257        for j in 0..6 {
258            new_state[j] = state[j] + dt * (k1[j] + 2.0 * k2[j] + 2.0 * k3[j] + k4[j]) / 6.0;
259        }
260
261        times.push(t + dt);
262        states.push(new_state);
263    }
264
265    // Create event arrays
266    let t_events = [
267        if hit_target {
268            vec![*times.last().unwrap()]
269        } else {
270            vec![]
271        },
272        if let Some(t) = max_ord_time {
273            vec![t]
274        } else {
275            vec![]
276        },
277        if hit_ground {
278            vec![*times.last().unwrap()]
279        } else {
280            vec![]
281        },
282    ];
283
284    FastSolution::from_trajectory_data(times, states, t_events)
285}
286
287/// Compute derivatives for the state vector
288fn compute_derivatives(
289    state: &[f64; 6],
290    inputs: &BallisticInputs,
291    wind_sock: &WindSock,
292    base_density: f64,
293    drag_model: &DragModel,
294    bc: f64,
295    has_bc_segments: bool,
296    has_bc_segments_data: bool,
297) -> [f64; 6] {
298    let pos = Vector3::new(state[0], state[1], state[2]);
299    let vel = Vector3::new(state[3], state[4], state[5]);
300
301    // Get wind vector (based on downrange distance, which is Z coordinate)
302    let wind_vector = wind_sock.vector_for_range_stateless(pos.z);
303
304    // Velocity relative to air
305    let vel_adjusted = vel - wind_vector;
306    let v_mag = vel_adjusted.norm();
307
308    // Calculate acceleration
309    let accel = if v_mag < 1e-6 {
310        Vector3::new(0.0, -G_ACCEL_MPS2, 0.0)
311    } else {
312        // Calculate drag
313        let v_fps = v_mag * MPS_TO_FPS;
314        let mach = v_mag / 340.0; // Approximate speed of sound
315
316        // Get BC value (potentially from segments)
317        let bc_current = if has_bc_segments_data && inputs.bc_segments_data.is_some() {
318            get_bc_from_velocity_segments(v_fps, inputs.bc_segments_data.as_ref().unwrap())
319        } else if has_bc_segments && inputs.bc_segments.is_some() {
320            crate::derivatives::interpolated_bc(
321                mach,
322                inputs.bc_segments.as_ref().unwrap(),
323                Some(inputs),
324            )
325        } else {
326            bc
327        };
328
329        let drag_factor = get_drag_coefficient(mach, drag_model);
330
331        // Calculate drag acceleration using proper ballistics formula
332        let cd_to_retard = 0.000683 * 0.30;
333        let standard_factor = drag_factor * cd_to_retard;
334        let density_scale = base_density / 1.225;
335
336        // Drag acceleration in ft/s^2
337        let a_drag_ft_s2 = (v_fps * v_fps) * standard_factor * density_scale / bc_current;
338
339        // Convert to m/s^2 and apply to velocity vector
340        let a_drag_m_s2 = a_drag_ft_s2 * 0.3048; // ft/s^2 to m/s^2
341        let accel_drag = -a_drag_m_s2 * (vel_adjusted / v_mag);
342
343        // Total acceleration
344        accel_drag + Vector3::new(0.0, -G_ACCEL_MPS2, 0.0)
345    };
346
347    // Return derivatives [vx, vy, vz, ax, ay, az]
348    [vel.x, vel.y, vel.z, accel.x, accel.y, accel.z]
349}
350
351/// Get BC from velocity-based segments
352fn get_bc_from_velocity_segments(velocity_fps: f64, segments: &[BCSegmentData]) -> f64 {
353    for segment in segments {
354        if velocity_fps >= segment.velocity_min && velocity_fps <= segment.velocity_max {
355            return segment.bc_value;
356        }
357    }
358
359    // If no matching segment, use the BC from the closest segment
360    if let Some(first) = segments.first() {
361        if velocity_fps < first.velocity_min {
362            return first.bc_value;
363        }
364    }
365
366    if let Some(last) = segments.last() {
367        if velocity_fps > last.velocity_max {
368            return last.bc_value;
369        }
370    }
371
372    // Fallback (shouldn't reach here if segments are properly defined)
373    0.5
374}
375
376/// Fast integration with explicit wind segments using RK45
377/// MBA-155: Upstreamed from ballistics_rust
378pub fn fast_integrate_with_segments(
379    inputs: &BallisticInputs,
380    wind_segments: Vec<crate::wind::WindSegment>,
381    params: FastIntegrationParams,
382) -> FastSolution {
383    // Use the RK45 implementation from trajectory_integration module
384    use crate::trajectory_integration::{integrate_trajectory, TrajectoryParams};
385
386    // Extract parameters
387    let mass_kg = inputs.bullet_mass * GRAINS_TO_KG;
388    let bc = inputs.bc_value;
389    let drag_model = inputs.bc_type;
390
391    // Get omega vector if advanced effects enabled
392    let omega_vector = if inputs.enable_advanced_effects {
393        // Calculate omega based on latitude
394        let omega_earth = 7.2921159e-5; // rad/s
395        let lat_rad = inputs.latitude.unwrap_or(0.0).to_radians();
396        Some(Vector3::new(
397            0.0,
398            omega_earth * lat_rad.cos(),
399            omega_earth * lat_rad.sin(),
400        ))
401    } else {
402        None
403    };
404
405    // Set up trajectory parameters
406    let traj_params = TrajectoryParams {
407        mass_kg,
408        bc,
409        drag_model,
410        wind_segments,
411        atmos_params: params.atmo_params,
412        omega_vector,
413        enable_spin_drift: inputs.enable_advanced_effects,
414        enable_magnus: inputs.enable_advanced_effects,
415        enable_coriolis: inputs.enable_advanced_effects,
416        target_distance_m: params.horiz,
417        enable_wind_shear: inputs.enable_wind_shear,
418        wind_shear_model: inputs.wind_shear_model.clone(),
419        shooter_altitude_m: inputs.altitude,
420        is_twist_right: inputs.is_twist_right,
421        custom_drag_table: inputs.custom_drag_table.clone(),
422        bc_segments: inputs.bc_segments.clone(),
423        use_bc_segments: inputs.use_bc_segments,
424    };
425
426    // Use RK45 adaptive integration
427    let trajectory = integrate_trajectory(
428        params.initial_state,
429        params.t_span,
430        traj_params,
431        "RK45", // Use RK45 implementation
432        1e-6,   // tolerance
433        0.01,   // max_step
434    );
435
436    // Convert trajectory to FastSolution format
437    let n_points = trajectory.len();
438    let mut times = Vec::with_capacity(n_points);
439    let mut states = Vec::with_capacity(n_points);
440
441    let mut target_hit_time: Option<f64> = None;
442    let mut ground_hit_time: Option<f64> = None;
443    let mut max_ord_time = None;
444    let mut max_ord_y = 0.0;
445
446    for (t, state_vec) in trajectory {
447        // Convert Vector6 to array
448        let state = [
449            state_vec[0],
450            state_vec[1],
451            state_vec[2],
452            state_vec[3],
453            state_vec[4],
454            state_vec[5],
455        ];
456
457        // Check termination conditions
458        // Z IS DOWNRANGE: state[0]=lateral, state[1]=vertical, state[2]=downrange
459
460        // Record FIRST time target is hit
461        if target_hit_time.is_none() && state[2] >= params.horiz {
462            target_hit_time = Some(t);
463        }
464
465        // Record ground hit
466        if ground_hit_time.is_none() && state[1] <= inputs.ground_threshold {
467            ground_hit_time = Some(t);
468        }
469
470        // Track maximum ordinate
471        if state[1] > max_ord_y {
472            max_ord_y = state[1];
473            max_ord_time = Some(t);
474        }
475
476        times.push(t);
477        states.push(state);
478    }
479
480    // Create event arrays
481    let t_events = [
482        if let Some(t) = target_hit_time {
483            vec![t]
484        } else {
485            vec![]
486        },
487        if let Some(t) = max_ord_time {
488            vec![t]
489        } else {
490            vec![]
491        },
492        if let Some(t) = ground_hit_time {
493            vec![t]
494        } else {
495            vec![]
496        },
497    ];
498
499    FastSolution::from_trajectory_data(times, states, t_events)
500}
501
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for comparing interpolated values.
    const EPS: f64 = 1e-10;

    #[test]
    fn test_fast_solution_interpolation() {
        let solution = FastSolution::from_trajectory_data(
            vec![0.0, 1.0, 2.0],
            vec![
                [0.0, 0.0, 0.0, 100.0, 50.0, 0.0],
                [100.0, 45.0, 0.0, 99.0, 40.0, 0.0],
                [198.0, 80.0, 0.0, 98.0, 30.0, 0.0],
            ],
            [vec![], vec![], vec![]],
        );

        // Halfway between the second and third samples.
        let midpoint = solution.sol(&[1.5]);

        // (state component, expected value): x position, y position, vx velocity.
        let expectations = [(0usize, 149.0), (1, 62.5), (3, 98.5)];
        for &(component, expected) in expectations.iter() {
            let actual = midpoint[component][0];
            assert!(
                (actual - expected).abs() < EPS,
                "component {}: got {}, expected {}",
                component,
                actual,
                expected
            );
        }
    }

    #[test]
    fn test_bc_from_velocity_segments() {
        let segments: Vec<BCSegmentData> = [
            (0.0, 1000.0, 0.5),
            (1000.0, 2000.0, 0.52),
            (2000.0, 3000.0, 0.55),
        ]
        .iter()
        .map(|&(velocity_min, velocity_max, bc_value)| BCSegmentData {
            velocity_min,
            velocity_max,
            bc_value,
        })
        .collect();

        let cases = [
            (500.0, 0.5),
            (1500.0, 0.52),
            (2500.0, 0.55),
            // Edge cases: below the lowest and above the highest segment.
            (-100.0, 0.5),
            (3500.0, 0.55),
        ];
        for &(velocity, expected) in cases.iter() {
            assert_eq!(
                get_bc_from_velocity_segments(velocity, &segments),
                expected,
                "velocity {} fps",
                velocity
            );
        }
    }
}