ballistics_engine/
fast_trajectory.rs

//! Fast trajectory solvers for longer ranges.
//!
//! Provides a fixed-step RK4 solver (`fast_integrate`) and a wrapper around the
//! adaptive RK45 integrator for explicit wind segments
//! (`fast_integrate_with_segments`); both return a `FastSolution`.
5
6use nalgebra::Vector3;
7use crate::{
8    InternalBallisticInputs as BallisticInputs, DragModel, BCSegmentData,
9    atmosphere::get_local_atmosphere,
10    drag::get_drag_coefficient,
11    wind::WindSock,
12    constants::{MPS_TO_FPS, GRAINS_TO_KG, G_ACCEL_MPS2},
13};
14
/// Fast solution container matching the Python implementation's layout.
#[derive(Debug, Clone)]
pub struct FastSolution {
    /// Sample times; must be sorted ascending (required by [`FastSolution::sol`]).
    pub t: Vec<f64>,
    /// State history in column-major form: `y[j][i]` is state component `j`
    /// (0..6) at time `t[i]`, i.e. 6 rows of `t.len()` samples each.
    pub y: Vec<Vec<f64>>,
    /// Event times indexed as `[target_hit, max_ord, ground_hit]`; an empty
    /// vector means the event was not observed.
    pub t_events: [Vec<f64>; 3],
    /// Whether integration succeeded.
    pub success: bool,
}

impl FastSolution {
    /// Linearly interpolates the state at each time in `t_query`.
    ///
    /// Returns a 6 x `t_query.len()` column-major matrix: `result[j][i]` is
    /// component `j` at `t_query[i]`. Queries at or before the first sample,
    /// or after the last one, are clamped to the first/last state.
    ///
    /// A NaN query, or any query against a solution with no samples, yields
    /// NaN for that column instead of panicking.
    pub fn sol(&self, t_query: &[f64]) -> Vec<Vec<f64>> {
        let mut result = vec![vec![f64::NAN; t_query.len()]; 6];
        let n = self.t.len();
        if n == 0 {
            return result;
        }

        for (col, &tq) in t_query.iter().enumerate() {
            if tq.is_nan() {
                continue;
            }

            // Number of samples strictly earlier than `tq`, i.e. the index of
            // the first sample at or after it. Unlike `binary_search`, this is
            // well-defined with duplicated sample times, so the interpolation
            // branch below always has `t0 < tq <= t1` and never divides by zero.
            let idx = self.t.partition_point(|&ts| ts < tq);

            if idx == 0 {
                // At or before the first sample.
                for (row, y_row) in result.iter_mut().zip(&self.y) {
                    row[col] = y_row[0];
                }
            } else if idx >= n {
                // After the last sample.
                for (row, y_row) in result.iter_mut().zip(&self.y) {
                    row[col] = y_row[n - 1];
                }
            } else {
                let (t0, t1) = (self.t[idx - 1], self.t[idx]);
                let frac = (tq - t0) / (t1 - t0);
                for (row, y_row) in result.iter_mut().zip(&self.y) {
                    let (y0, y1) = (y_row[idx - 1], y_row[idx]);
                    row[col] = y0 + frac * (y1 - y0);
                }
            }
        }

        result
    }

    /// Builds a solution from row-major samples (`states[i]` is the full state
    /// at `times[i]`), transposing them into the column-major `y` layout.
    ///
    /// Every row of `y` has exactly `times.len()` entries: states beyond
    /// `times.len()` are ignored and missing states are zero-filled, so a
    /// length mismatch never panics. `success` is always `true`.
    pub fn from_trajectory_data(
        times: Vec<f64>,
        states: Vec<[f64; 6]>,
        t_events: [Vec<f64>; 3],
    ) -> Self {
        let n_points = times.len();
        let y = (0..6)
            .map(|j| {
                let mut row: Vec<f64> = states.iter().take(n_points).map(|s| s[j]).collect();
                row.resize(n_points, 0.0);
                row
            })
            .collect();

        FastSolution {
            t: times,
            y,
            t_events,
            success: true,
        }
    }
}
86
/// Fast trajectory integration parameters, shared by `fast_integrate` and
/// `fast_integrate_with_segments`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct FastIntegrationParams {
    /// Target horizontal distance: integration reports a target hit once the
    /// downrange coordinate reaches it (forwarded as `target_distance_m`).
    pub horiz: f64,
    /// Presumably the target's vertical offset.
    /// NOTE(review): not read by either solver in this module — confirm use.
    pub vert: f64,
    /// Initial state `[x, y, z, vx, vy, vz]`.
    pub initial_state: [f64; 6],
    /// `(start, end)` integration time window.
    pub t_span: (f64, f64),
    /// Four atmosphere parameters forwarded to `get_local_atmosphere` and to
    /// `TrajectoryParams::atmos_params`; meaning/order defined there — TODO confirm.
    pub atmo_params: (f64, f64, f64, f64),
}
95
96/// Fast fixed-step integration for longer trajectories
97pub fn fast_integrate(
98    inputs: &BallisticInputs,
99    wind_sock: &WindSock,
100    params: FastIntegrationParams,
101) -> FastSolution {
102    // Extract parameters
103    let _mass_kg = inputs.bullet_mass * GRAINS_TO_KG;
104    let bc = inputs.bc_value;
105    let drag_model = &inputs.bc_type;
106    
107    // Check for BC segments
108    let has_bc_segments = inputs.bc_segments.is_some() && !inputs.bc_segments.as_ref().unwrap().is_empty();
109    let has_bc_segments_data = inputs.bc_segments_data.is_some() && !inputs.bc_segments_data.as_ref().unwrap().is_empty();
110    
111    // Time step - adjust based on distance
112    let dt = if params.horiz > 200.0 {
113        0.001
114    } else if params.horiz > 100.0 {
115        0.0005
116    } else {
117        0.0001
118    };
119    
120    // Maximum time based on estimated flight time
121    let v0 = Vector3::new(
122        params.initial_state[3],
123        params.initial_state[4],
124        params.initial_state[5]
125    ).norm();
126    
127    let t_max = if v0 > 1e-6 && params.horiz > 0.0 {
128        (2.0 * params.horiz / v0).min(params.t_span.1)
129    } else {
130        params.t_span.1
131    };
132    
133    // Initialize arrays
134    let n_steps = ((t_max / dt) as usize) + 1;
135    let mut times = Vec::with_capacity(n_steps);
136    let mut states = Vec::with_capacity(n_steps);
137    
138    // Initial state
139    times.push(0.0);
140    states.push(params.initial_state);
141    
142    // Get base atmospheric density
143    let (base_density, _) = get_local_atmosphere(
144        0.0,
145        params.atmo_params.0,
146        params.atmo_params.1,
147        params.atmo_params.2,
148        params.atmo_params.3,
149    );
150    
151    // Integration loop
152    let mut hit_target = false;
153    let mut hit_ground = false;
154    let mut max_ord_time = None;
155    let mut max_ord_y = 0.0;
156    let ground_threshold = inputs.ground_threshold;
157    
158    // RK4 integration
159    for i in 0..n_steps-1 {
160        let t = i as f64 * dt;
161        let state = states[i];
162        
163        let pos = Vector3::new(state[0], state[1], state[2]);
164        let _vel = Vector3::new(state[3], state[4], state[5]);
165        
166        // Check termination conditions
167        if pos.x >= params.horiz {
168            hit_target = true;
169            times.push(t);
170            states.push(state);
171            break;
172        }
173        
174        if pos.y <= ground_threshold {
175            hit_ground = true;
176            times.push(t);
177            states.push(state);
178            break;
179        }
180        
181        // Track maximum ordinate
182        if pos.y > max_ord_y {
183            max_ord_y = pos.y;
184            max_ord_time = Some(t);
185        }
186        
187        // RK4 step
188        let k1 = compute_derivatives(&state, inputs, wind_sock, base_density, drag_model, bc, has_bc_segments, has_bc_segments_data);
189        
190        let mut state2 = state;
191        for j in 0..6 {
192            state2[j] = state[j] + 0.5 * dt * k1[j];
193        }
194        let k2 = compute_derivatives(&state2, inputs, wind_sock, base_density, drag_model, bc, has_bc_segments, has_bc_segments_data);
195        
196        let mut state3 = state;
197        for j in 0..6 {
198            state3[j] = state[j] + 0.5 * dt * k2[j];
199        }
200        let k3 = compute_derivatives(&state3, inputs, wind_sock, base_density, drag_model, bc, has_bc_segments, has_bc_segments_data);
201        
202        let mut state4 = state;
203        for j in 0..6 {
204            state4[j] = state[j] + dt * k3[j];
205        }
206        let k4 = compute_derivatives(&state4, inputs, wind_sock, base_density, drag_model, bc, has_bc_segments, has_bc_segments_data);
207        
208        // Update state
209        let mut new_state = state;
210        for j in 0..6 {
211            new_state[j] = state[j] + dt * (k1[j] + 2.0 * k2[j] + 2.0 * k3[j] + k4[j]) / 6.0;
212        }
213        
214        times.push(t + dt);
215        states.push(new_state);
216    }
217    
218    // Create event arrays
219    let t_events = [
220        if hit_target { vec![*times.last().unwrap()] } else { vec![] },
221        if let Some(t) = max_ord_time { vec![t] } else { vec![] },
222        if hit_ground { vec![*times.last().unwrap()] } else { vec![] },
223    ];
224    
225    FastSolution::from_trajectory_data(times, states, t_events)
226}
227
228/// Compute derivatives for the state vector
229fn compute_derivatives(
230    state: &[f64; 6],
231    inputs: &BallisticInputs,
232    wind_sock: &WindSock,
233    base_density: f64,
234    drag_model: &DragModel,
235    bc: f64,
236    has_bc_segments: bool,
237    has_bc_segments_data: bool,
238) -> [f64; 6] {
239    let pos = Vector3::new(state[0], state[1], state[2]);
240    let vel = Vector3::new(state[3], state[4], state[5]);
241    
242    // Get wind vector
243    let wind_vector = wind_sock.vector_for_range_stateless(pos.x);
244    
245    // Velocity relative to air
246    let vel_adjusted = vel - wind_vector;
247    let v_mag = vel_adjusted.norm();
248    
249    // Calculate acceleration
250    let accel = if v_mag < 1e-6 {
251        Vector3::new(0.0, -G_ACCEL_MPS2, 0.0)
252    } else {
253        // Calculate drag
254        let v_fps = v_mag * MPS_TO_FPS;
255        let mach = v_mag / 340.0; // Approximate speed of sound
256        
257        // Get BC value (potentially from segments)
258        let bc_current = if has_bc_segments_data && inputs.bc_segments_data.is_some() {
259            get_bc_from_velocity_segments(v_fps, inputs.bc_segments_data.as_ref().unwrap())
260        } else if has_bc_segments && inputs.bc_segments.is_some() {
261            crate::derivatives::interpolated_bc(mach, inputs.bc_segments.as_ref().unwrap(), Some(inputs))
262        } else {
263            bc
264        };
265        
266        let drag_factor = get_drag_coefficient(mach, drag_model);
267        
268        // Calculate drag acceleration using proper ballistics formula
269        let cd_to_retard = 0.000683 * 0.30;
270        let standard_factor = drag_factor * cd_to_retard;
271        let density_scale = base_density / 1.225;
272        
273        // Drag acceleration in ft/s^2
274        let a_drag_ft_s2 = (v_fps * v_fps) * standard_factor * density_scale / bc_current;
275        
276        // Convert to m/s^2 and apply to velocity vector
277        let a_drag_m_s2 = a_drag_ft_s2 * 0.3048; // ft/s^2 to m/s^2
278        let accel_drag = -a_drag_m_s2 * (vel_adjusted / v_mag);
279        
280        // Total acceleration
281        accel_drag + Vector3::new(0.0, -G_ACCEL_MPS2, 0.0)
282    };
283    
284    // Return derivatives [vx, vy, vz, ax, ay, az]
285    [
286        vel.x,
287        vel.y,
288        vel.z,
289        accel.x,
290        accel.y,
291        accel.z,
292    ]
293}
294
295/// Get BC from velocity-based segments
296fn get_bc_from_velocity_segments(velocity_fps: f64, segments: &[BCSegmentData]) -> f64 {
297    for segment in segments {
298        if velocity_fps >= segment.velocity_min && velocity_fps <= segment.velocity_max {
299            return segment.bc_value;
300        }
301    }
302    
303    // If no matching segment, use the BC from the closest segment
304    if let Some(first) = segments.first() {
305        if velocity_fps < first.velocity_min {
306            return first.bc_value;
307        }
308    }
309    
310    if let Some(last) = segments.last() {
311        if velocity_fps > last.velocity_max {
312            return last.bc_value;
313        }
314    }
315    
316    // Fallback (shouldn't reach here if segments are properly defined)
317    0.5
318}
319
320/// Fast integration with explicit wind segments using RK45
321/// MBA-155: Upstreamed from ballistics_rust
322pub fn fast_integrate_with_segments(
323    inputs: &BallisticInputs,
324    wind_segments: Vec<crate::wind::WindSegment>,
325    params: FastIntegrationParams,
326) -> FastSolution {
327    // Use the RK45 implementation from trajectory_integration module
328    use crate::trajectory_integration::{integrate_trajectory, TrajectoryParams};
329
330    // Extract parameters
331    let mass_kg = inputs.bullet_mass * GRAINS_TO_KG;
332    let bc = inputs.bc_value;
333    let drag_model = inputs.bc_type.clone();
334
335    // Get omega vector if advanced effects enabled
336    let omega_vector = if inputs.enable_advanced_effects {
337        // Calculate omega based on latitude
338        let omega_earth = 7.2921159e-5; // rad/s
339        let lat_rad = inputs.latitude.unwrap_or(0.0).to_radians();
340        Some(Vector3::new(
341            0.0,
342            omega_earth * lat_rad.cos(),
343            omega_earth * lat_rad.sin(),
344        ))
345    } else {
346        None
347    };
348
349    // Set up trajectory parameters
350    let traj_params = TrajectoryParams {
351        mass_kg,
352        bc,
353        drag_model,
354        wind_segments,
355        atmos_params: params.atmo_params,
356        omega_vector,
357        enable_spin_drift: inputs.enable_advanced_effects,
358        enable_magnus: inputs.enable_advanced_effects,
359        enable_coriolis: inputs.enable_advanced_effects,
360        target_distance_m: params.horiz,
361        enable_wind_shear: inputs.enable_wind_shear,
362        wind_shear_model: inputs.wind_shear_model.clone(),
363        shooter_altitude_m: inputs.altitude,
364        is_twist_right: inputs.is_twist_right,
365        custom_drag_table: inputs.custom_drag_table.clone(),
366    };
367
368    // Use RK45 adaptive integration
369    let trajectory = integrate_trajectory(
370        params.initial_state,
371        params.t_span,
372        traj_params,
373        "RK45",  // Use RK45 implementation
374        1e-6,    // tolerance
375        0.01,    // max_step
376    );
377
378    // Convert trajectory to FastSolution format
379    let n_points = trajectory.len();
380    let mut times = Vec::with_capacity(n_points);
381    let mut states = Vec::with_capacity(n_points);
382
383    let mut target_hit_time: Option<f64> = None;
384    let mut ground_hit_time: Option<f64> = None;
385    let mut max_ord_time = None;
386    let mut max_ord_y = 0.0;
387
388    for (t, state_vec) in trajectory {
389        // Convert Vector6 to array
390        let state = [
391            state_vec[0], state_vec[1], state_vec[2],
392            state_vec[3], state_vec[4], state_vec[5],
393        ];
394
395        // Check termination conditions
396        // Z IS DOWNRANGE: state[0]=lateral, state[1]=vertical, state[2]=downrange
397
398        // Record FIRST time target is hit
399        if target_hit_time.is_none() && state[2] >= params.horiz {
400            target_hit_time = Some(t);
401        }
402
403        // Record ground hit
404        if ground_hit_time.is_none() && state[1] <= inputs.ground_threshold {
405            ground_hit_time = Some(t);
406        }
407
408        // Track maximum ordinate
409        if state[1] > max_ord_y {
410            max_ord_y = state[1];
411            max_ord_time = Some(t);
412        }
413
414        times.push(t);
415        states.push(state);
416    }
417
418    // Create event arrays
419    let t_events = [
420        if let Some(t) = target_hit_time { vec![t] } else { vec![] },
421        if let Some(t) = max_ord_time { vec![t] } else { vec![] },
422        if let Some(t) = ground_hit_time { vec![t] } else { vec![] },
423    ];
424
425    FastSolution::from_trajectory_data(times, states, t_events)
426}
427
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for interpolated values.
    const EPS: f64 = 1e-10;

    fn close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < EPS
    }

    #[test]
    fn test_fast_solution_interpolation() {
        let solution = FastSolution::from_trajectory_data(
            vec![0.0, 1.0, 2.0],
            vec![
                [0.0, 0.0, 0.0, 100.0, 50.0, 0.0],
                [100.0, 45.0, 0.0, 99.0, 40.0, 0.0],
                [198.0, 80.0, 0.0, 98.0, 30.0, 0.0],
            ],
            [Vec::new(), Vec::new(), Vec::new()],
        );

        // Halfway between the samples at t = 1 and t = 2.
        let interpolated = solution.sol(&[1.5]);

        // (state component, expected value): x position, y position, vx velocity.
        for &(component, expected) in &[(0usize, 149.0), (1, 62.5), (3, 98.5)] {
            assert!(close(interpolated[component][0], expected));
        }
    }

    #[test]
    fn test_bc_from_velocity_segments() {
        let segments: Vec<BCSegmentData> = [(0.0, 1000.0, 0.5), (1000.0, 2000.0, 0.52), (2000.0, 3000.0, 0.55)]
            .iter()
            .map(|&(velocity_min, velocity_max, bc_value)| BCSegmentData {
                velocity_min,
                velocity_max,
                bc_value,
            })
            .collect();

        // Inside each band, then below the minimum and above the maximum.
        let cases = [
            (500.0, 0.5),
            (1500.0, 0.52),
            (2500.0, 0.55),
            (-100.0, 0.5),
            (3500.0, 0.55),
        ];
        for &(velocity, expected) in &cases {
            assert_eq!(get_bc_from_velocity_segments(velocity, &segments), expected);
        }
    }
}