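// ballistics_engine/fast_trajectory.rs

//! Fast trajectory solver for longer ranges.
//!
//! This is a Rust implementation of the fast fixed-step (RK4) trajectory solver,
//! which provides significant performance improvements for long-range calculations.
//! [`fast_integrate_with_segments`] instead delegates to the adaptive RK45
//! integrator when explicit wind segments are supplied.
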
6use nalgebra::Vector3;
7use crate::{
8    InternalBallisticInputs as BallisticInputs, DragModel, BCSegmentData,
9    atmosphere::get_local_atmosphere,
10    drag::get_drag_coefficient,
11    wind::WindSock,
12    constants::{MPS_TO_FPS, GRAINS_TO_KG, G_ACCEL_MPS2},
13};
14
/// Fast solution container matching the Python implementation.
///
/// Samples are stored column-major: `y[j][i]` is state component `j` at time `t[i]`.
/// The solvers in this module order the six components as `[x, y, z, vx, vy, vz]`
/// (y vertical, z downrange).
#[derive(Debug, Clone)]
pub struct FastSolution {
    /// Sample times, expected to be non-decreasing.
    pub t: Vec<f64>,
    /// State samples `[6 x n_points]`: one row per state component, one column per time point.
    pub y: Vec<Vec<f64>>,
    /// Event times `[target_hit, max_ord, ground_hit]`; an empty vector means the event did not occur.
    pub t_events: [Vec<f64>; 3],
    /// Whether integration succeeded.
    pub success: bool,
}

impl FastSolution {
    /// Linearly interpolate the solution at each query time.
    ///
    /// Returns a `[6 x t_query.len()]` matrix in the same row-per-component layout as
    /// `y`. Queries before the first or after the last sample are clamped to that
    /// sample (no extrapolation), and a query that lands exactly on a sample returns
    /// that sample unchanged. Repeated timestamps are handled without producing NaN.
    /// A NaN query, or a solution with no samples, yields NaN for every component
    /// instead of panicking.
    ///
    /// # Panics
    /// May panic if `y` does not have six rows of at least `t.len()` entries.
    pub fn sol(&self, t_query: &[f64]) -> Vec<Vec<f64>> {
        let mut result = vec![vec![f64::NAN; t_query.len()]; 6];
        let n = self.t.len();
        if n == 0 {
            return result;
        }

        for (i, &tq) in t_query.iter().enumerate() {
            if tq.is_nan() {
                // No meaningful position on the time axis; leave this column NaN.
                continue;
            }

            // Index of the first sample with t >= tq (binary search on the sorted grid).
            let idx = self.t.partition_point(|&t| t < tq);

            let value = if idx == 0 {
                self.state_at(0)
            } else if idx >= n {
                self.state_at(n - 1)
            } else {
                let (t0, t1) = (self.t[idx - 1], self.t[idx]);
                let upper = self.state_at(idx);
                // Exact hit, or a zero-width/out-of-order interval (e.g. a repeated
                // timestamp) where the interpolation fraction would be 0/0.
                if tq >= t1 || t1 <= t0 {
                    upper
                } else {
                    let lower = self.state_at(idx - 1);
                    let frac = (tq - t0) / (t1 - t0);
                    let mut blended = [0.0; 6];
                    for ((out, &y0), &y1) in blended.iter_mut().zip(lower.iter()).zip(upper.iter()) {
                        *out = y0 + frac * (y1 - y0);
                    }
                    blended
                }
            };

            for (row, &v) in result.iter_mut().zip(value.iter()) {
                row[i] = v;
            }
        }

        result
    }

    /// Full state vector stored at sample index `k` (column `k` of `y`).
    fn state_at(&self, k: usize) -> [f64; 6] {
        let mut state = [0.0; 6];
        for (j, value) in state.iter_mut().enumerate() {
            *value = self.y[j][k];
        }
        state
    }

    /// Build a solution from row-major samples, one `[f64; 6]` state per time point.
    ///
    /// The states are transposed into the column-major `y` layout used by
    /// [`FastSolution::sol`]. `states` is expected to have one entry per element of
    /// `times`; if it is shorter, the missing columns are left at `0.0`. The result is
    /// always marked `success = true`.
    ///
    /// # Panics
    /// If `states` has more entries than `times`.
    pub fn from_trajectory_data(times: Vec<f64>, states: Vec<[f64; 6]>, t_events: [Vec<f64>; 3]) -> Self {
        let mut y = vec![vec![0.0; times.len()]; 6];

        for (i, state) in states.iter().enumerate() {
            for (row, &value) in y.iter_mut().zip(state.iter()) {
                row[i] = value;
            }
        }

        FastSolution {
            t: times,
            y,
            t_events,
            success: true,
        }
    }
}
87
/// Parameters for the fast trajectory integrators in this module.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct FastIntegrationParams {
    /// Target distance along the downrange (z) axis, in meters; used as the
    /// target-hit threshold and forwarded as the target distance to RK45.
    pub horiz: f64,
    /// Vertical target parameter.
    /// NOTE(review): not read by the integrators in this file — confirm its meaning with callers.
    pub vert: f64,
    /// Initial state `[x, y, z, vx, vy, vz]` (x lateral, y vertical, z downrange).
    pub initial_state: [f64; 6],
    /// Time window `(start, end)` in seconds. The fixed-step solver always starts at
    /// t = 0 and uses only `end` as an upper bound; the RK45 path receives both.
    pub t_span: (f64, f64),
    /// Atmosphere inputs forwarded to `get_local_atmosphere` and the RK45 integrator.
    /// NOTE(review): the meaning of the four entries is defined there — confirm before relying on it.
    pub atmo_params: (f64, f64, f64, f64),
}
96
97/// Fast fixed-step integration for longer trajectories
98pub fn fast_integrate(
99    inputs: &BallisticInputs,
100    wind_sock: &WindSock,
101    params: FastIntegrationParams,
102) -> FastSolution {
103    // Extract parameters
104    let _mass_kg = inputs.bullet_mass * GRAINS_TO_KG;
105    let bc = inputs.bc_value;
106    let drag_model = &inputs.bc_type;
107    
108    // Check for BC segments
109    let has_bc_segments = inputs.bc_segments.is_some() && !inputs.bc_segments.as_ref().unwrap().is_empty();
110    let has_bc_segments_data = inputs.bc_segments_data.is_some() && !inputs.bc_segments_data.as_ref().unwrap().is_empty();
111    
112    // Time step - adjust based on distance
113    let dt = if params.horiz > 200.0 {
114        0.001
115    } else if params.horiz > 100.0 {
116        0.0005
117    } else {
118        0.0001
119    };
120    
121    // Maximum time based on estimated flight time
122    let v0 = Vector3::new(
123        params.initial_state[3],
124        params.initial_state[4],
125        params.initial_state[5]
126    ).norm();
127    
128    let t_max = if v0 > 1e-6 && params.horiz > 0.0 {
129        (2.0 * params.horiz / v0).min(params.t_span.1)
130    } else {
131        params.t_span.1
132    };
133    
134    // Initialize arrays
135    let n_steps = ((t_max / dt) as usize) + 1;
136    let mut times = Vec::with_capacity(n_steps);
137    let mut states = Vec::with_capacity(n_steps);
138    
139    // Initial state
140    times.push(0.0);
141    states.push(params.initial_state);
142    
143    // Get base atmospheric density
144    let (base_density, _) = get_local_atmosphere(
145        0.0,
146        params.atmo_params.0,
147        params.atmo_params.1,
148        params.atmo_params.2,
149        params.atmo_params.3,
150    );
151    
152    // Integration loop
153    let mut hit_target = false;
154    let mut hit_ground = false;
155    let mut max_ord_time = None;
156    let mut max_ord_y = 0.0;
157    let ground_threshold = inputs.ground_threshold;
158    
159    // RK4 integration
160    for i in 0..n_steps-1 {
161        let t = i as f64 * dt;
162        let state = states[i];
163        
164        let pos = Vector3::new(state[0], state[1], state[2]);
165        let _vel = Vector3::new(state[3], state[4], state[5]);
166
167        // Check termination conditions (z is downrange)
168        if pos.z >= params.horiz {
169            hit_target = true;
170            times.push(t);
171            states.push(state);
172            break;
173        }
174        
175        if pos.y <= ground_threshold {
176            hit_ground = true;
177            times.push(t);
178            states.push(state);
179            break;
180        }
181        
182        // Track maximum ordinate
183        if pos.y > max_ord_y {
184            max_ord_y = pos.y;
185            max_ord_time = Some(t);
186        }
187        
188        // RK4 step
189        let k1 = compute_derivatives(&state, inputs, wind_sock, base_density, drag_model, bc, has_bc_segments, has_bc_segments_data);
190        
191        let mut state2 = state;
192        for j in 0..6 {
193            state2[j] = state[j] + 0.5 * dt * k1[j];
194        }
195        let k2 = compute_derivatives(&state2, inputs, wind_sock, base_density, drag_model, bc, has_bc_segments, has_bc_segments_data);
196        
197        let mut state3 = state;
198        for j in 0..6 {
199            state3[j] = state[j] + 0.5 * dt * k2[j];
200        }
201        let k3 = compute_derivatives(&state3, inputs, wind_sock, base_density, drag_model, bc, has_bc_segments, has_bc_segments_data);
202        
203        let mut state4 = state;
204        for j in 0..6 {
205            state4[j] = state[j] + dt * k3[j];
206        }
207        let k4 = compute_derivatives(&state4, inputs, wind_sock, base_density, drag_model, bc, has_bc_segments, has_bc_segments_data);
208        
209        // Update state
210        let mut new_state = state;
211        for j in 0..6 {
212            new_state[j] = state[j] + dt * (k1[j] + 2.0 * k2[j] + 2.0 * k3[j] + k4[j]) / 6.0;
213        }
214        
215        times.push(t + dt);
216        states.push(new_state);
217    }
218    
219    // Create event arrays
220    let t_events = [
221        if hit_target { vec![*times.last().unwrap()] } else { vec![] },
222        if let Some(t) = max_ord_time { vec![t] } else { vec![] },
223        if hit_ground { vec![*times.last().unwrap()] } else { vec![] },
224    ];
225    
226    FastSolution::from_trajectory_data(times, states, t_events)
227}
228
229/// Compute derivatives for the state vector
230fn compute_derivatives(
231    state: &[f64; 6],
232    inputs: &BallisticInputs,
233    wind_sock: &WindSock,
234    base_density: f64,
235    drag_model: &DragModel,
236    bc: f64,
237    has_bc_segments: bool,
238    has_bc_segments_data: bool,
239) -> [f64; 6] {
240    let pos = Vector3::new(state[0], state[1], state[2]);
241    let vel = Vector3::new(state[3], state[4], state[5]);
242
243    // Get wind vector (based on downrange distance, which is Z coordinate)
244    let wind_vector = wind_sock.vector_for_range_stateless(pos.z);
245    
246    // Velocity relative to air
247    let vel_adjusted = vel - wind_vector;
248    let v_mag = vel_adjusted.norm();
249    
250    // Calculate acceleration
251    let accel = if v_mag < 1e-6 {
252        Vector3::new(0.0, -G_ACCEL_MPS2, 0.0)
253    } else {
254        // Calculate drag
255        let v_fps = v_mag * MPS_TO_FPS;
256        let mach = v_mag / 340.0; // Approximate speed of sound
257        
258        // Get BC value (potentially from segments)
259        let bc_current = if has_bc_segments_data && inputs.bc_segments_data.is_some() {
260            get_bc_from_velocity_segments(v_fps, inputs.bc_segments_data.as_ref().unwrap())
261        } else if has_bc_segments && inputs.bc_segments.is_some() {
262            crate::derivatives::interpolated_bc(mach, inputs.bc_segments.as_ref().unwrap(), Some(inputs))
263        } else {
264            bc
265        };
266        
267        let drag_factor = get_drag_coefficient(mach, drag_model);
268        
269        // Calculate drag acceleration using proper ballistics formula
270        let cd_to_retard = 0.000683 * 0.30;
271        let standard_factor = drag_factor * cd_to_retard;
272        let density_scale = base_density / 1.225;
273        
274        // Drag acceleration in ft/s^2
275        let a_drag_ft_s2 = (v_fps * v_fps) * standard_factor * density_scale / bc_current;
276        
277        // Convert to m/s^2 and apply to velocity vector
278        let a_drag_m_s2 = a_drag_ft_s2 * 0.3048; // ft/s^2 to m/s^2
279        let accel_drag = -a_drag_m_s2 * (vel_adjusted / v_mag);
280        
281        // Total acceleration
282        accel_drag + Vector3::new(0.0, -G_ACCEL_MPS2, 0.0)
283    };
284    
285    // Return derivatives [vx, vy, vz, ax, ay, az]
286    [
287        vel.x,
288        vel.y,
289        vel.z,
290        accel.x,
291        accel.y,
292        accel.z,
293    ]
294}
295
296/// Get BC from velocity-based segments
297fn get_bc_from_velocity_segments(velocity_fps: f64, segments: &[BCSegmentData]) -> f64 {
298    for segment in segments {
299        if velocity_fps >= segment.velocity_min && velocity_fps <= segment.velocity_max {
300            return segment.bc_value;
301        }
302    }
303    
304    // If no matching segment, use the BC from the closest segment
305    if let Some(first) = segments.first() {
306        if velocity_fps < first.velocity_min {
307            return first.bc_value;
308        }
309    }
310    
311    if let Some(last) = segments.last() {
312        if velocity_fps > last.velocity_max {
313            return last.bc_value;
314        }
315    }
316    
317    // Fallback (shouldn't reach here if segments are properly defined)
318    0.5
319}
320
321/// Fast integration with explicit wind segments using RK45
322/// MBA-155: Upstreamed from ballistics_rust
323pub fn fast_integrate_with_segments(
324    inputs: &BallisticInputs,
325    wind_segments: Vec<crate::wind::WindSegment>,
326    params: FastIntegrationParams,
327) -> FastSolution {
328    // Use the RK45 implementation from trajectory_integration module
329    use crate::trajectory_integration::{integrate_trajectory, TrajectoryParams};
330
331    // Extract parameters
332    let mass_kg = inputs.bullet_mass * GRAINS_TO_KG;
333    let bc = inputs.bc_value;
334    let drag_model = inputs.bc_type.clone();
335
336    // Get omega vector if advanced effects enabled
337    let omega_vector = if inputs.enable_advanced_effects {
338        // Calculate omega based on latitude
339        let omega_earth = 7.2921159e-5; // rad/s
340        let lat_rad = inputs.latitude.unwrap_or(0.0).to_radians();
341        Some(Vector3::new(
342            0.0,
343            omega_earth * lat_rad.cos(),
344            omega_earth * lat_rad.sin(),
345        ))
346    } else {
347        None
348    };
349
350    // Set up trajectory parameters
351    let traj_params = TrajectoryParams {
352        mass_kg,
353        bc,
354        drag_model,
355        wind_segments,
356        atmos_params: params.atmo_params,
357        omega_vector,
358        enable_spin_drift: inputs.enable_advanced_effects,
359        enable_magnus: inputs.enable_advanced_effects,
360        enable_coriolis: inputs.enable_advanced_effects,
361        target_distance_m: params.horiz,
362        enable_wind_shear: inputs.enable_wind_shear,
363        wind_shear_model: inputs.wind_shear_model.clone(),
364        shooter_altitude_m: inputs.altitude,
365        is_twist_right: inputs.is_twist_right,
366        custom_drag_table: inputs.custom_drag_table.clone(),
367        bc_segments: inputs.bc_segments.clone(),
368        use_bc_segments: inputs.use_bc_segments,
369    };
370
371    // Use RK45 adaptive integration
372    let trajectory = integrate_trajectory(
373        params.initial_state,
374        params.t_span,
375        traj_params,
376        "RK45",  // Use RK45 implementation
377        1e-6,    // tolerance
378        0.01,    // max_step
379    );
380
381    // Convert trajectory to FastSolution format
382    let n_points = trajectory.len();
383    let mut times = Vec::with_capacity(n_points);
384    let mut states = Vec::with_capacity(n_points);
385
386    let mut target_hit_time: Option<f64> = None;
387    let mut ground_hit_time: Option<f64> = None;
388    let mut max_ord_time = None;
389    let mut max_ord_y = 0.0;
390
391    for (t, state_vec) in trajectory {
392        // Convert Vector6 to array
393        let state = [
394            state_vec[0], state_vec[1], state_vec[2],
395            state_vec[3], state_vec[4], state_vec[5],
396        ];
397
398        // Check termination conditions
399        // Z IS DOWNRANGE: state[0]=lateral, state[1]=vertical, state[2]=downrange
400
401        // Record FIRST time target is hit
402        if target_hit_time.is_none() && state[2] >= params.horiz {
403            target_hit_time = Some(t);
404        }
405
406        // Record ground hit
407        if ground_hit_time.is_none() && state[1] <= inputs.ground_threshold {
408            ground_hit_time = Some(t);
409        }
410
411        // Track maximum ordinate
412        if state[1] > max_ord_y {
413            max_ord_y = state[1];
414            max_ord_time = Some(t);
415        }
416
417        times.push(t);
418        states.push(state);
419    }
420
421    // Create event arrays
422    let t_events = [
423        if let Some(t) = target_hit_time { vec![t] } else { vec![] },
424        if let Some(t) = max_ord_time { vec![t] } else { vec![] },
425        if let Some(t) = ground_hit_time { vec![t] } else { vec![] },
426    ];
427
428    FastSolution::from_trajectory_data(times, states, t_events)
429}
430
#[cfg(test)]
mod tests {
    use super::*;

    /// Three evenly spaced samples shared by the interpolation tests.
    fn sample_solution() -> FastSolution {
        let times = vec![0.0, 1.0, 2.0];
        let states = vec![
            [0.0, 0.0, 0.0, 100.0, 50.0, 0.0],
            [100.0, 45.0, 0.0, 99.0, 40.0, 0.0],
            [198.0, 80.0, 0.0, 98.0, 30.0, 0.0],
        ];
        FastSolution::from_trajectory_data(times, states, [vec![], vec![], vec![]])
    }

    /// Contiguous velocity bands covering 0-3000 fps.
    fn sample_segments() -> Vec<BCSegmentData> {
        vec![
            BCSegmentData { velocity_min: 0.0, velocity_max: 1000.0, bc_value: 0.5 },
            BCSegmentData { velocity_min: 1000.0, velocity_max: 2000.0, bc_value: 0.52 },
            BCSegmentData { velocity_min: 2000.0, velocity_max: 3000.0, bc_value: 0.55 },
        ]
    }

    #[test]
    fn test_fast_solution_interpolation() {
        let solution = sample_solution();

        // Midway between the second and third samples.
        let result = solution.sol(&[1.5]);

        assert!((result[0][0] - 149.0).abs() < 1e-10); // x position
        assert!((result[1][0] - 62.5).abs() < 1e-10); // y position
        assert!((result[3][0] - 98.5).abs() < 1e-10); // vx velocity
    }

    #[test]
    fn test_from_trajectory_data_transposes_states() {
        let solution = sample_solution();

        assert!(solution.success);
        assert_eq!(solution.t, vec![0.0, 1.0, 2.0]);
        assert_eq!(solution.y.len(), 6);
        assert_eq!(solution.y[0], vec![0.0, 100.0, 198.0]);
        assert_eq!(solution.y[4], vec![50.0, 40.0, 30.0]);
        assert!(solution.t_events.iter().all(Vec::is_empty));
    }

    #[test]
    fn test_fast_solution_clamps_and_returns_knots() {
        let solution = sample_solution();
        let result = solution.sol(&[-1.0, 0.0, 1.0, 2.0, 5.0]);

        // Before / at the first sample.
        assert_eq!(result[0][0], 0.0);
        assert_eq!(result[0][1], 0.0);
        // Exactly on an interior sample.
        assert_eq!(result[0][2], 100.0);
        assert_eq!(result[1][2], 45.0);
        // On / after the last sample.
        assert_eq!(result[0][3], 198.0);
        assert_eq!(result[0][4], 198.0);
    }

    #[test]
    fn test_bc_from_velocity_segments() {
        let segments = sample_segments();

        assert_eq!(get_bc_from_velocity_segments(500.0, &segments), 0.5);
        assert_eq!(get_bc_from_velocity_segments(1500.0, &segments), 0.52);
        assert_eq!(get_bc_from_velocity_segments(2500.0, &segments), 0.55);

        // A shared boundary resolves to the first matching (lower) band.
        assert_eq!(get_bc_from_velocity_segments(1000.0, &segments), 0.5);
        assert_eq!(get_bc_from_velocity_segments(3000.0, &segments), 0.55);

        // Test edge cases
        assert_eq!(get_bc_from_velocity_segments(-100.0, &segments), 0.5); // Below min
        assert_eq!(get_bc_from_velocity_segments(3500.0, &segments), 0.55); // Above max
    }

    #[test]
    fn test_bc_from_velocity_segments_empty_falls_back() {
        assert_eq!(get_bc_from_velocity_segments(1500.0, &[]), 0.5);
    }
}