Skip to main content

ballistics_engine/
trajectory_sampling.rs

1use nalgebra::Vector3;
2use std::collections::HashSet;
3
/// Notable events that can be attached to a [`TrajectorySample`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum TrajectoryFlag {
    /// The drop relative to the line of sight crosses (or touches) zero.
    ZeroCrossing,
    /// A Mach transition (from `TrajectoryData::transonic_distances`) lies nearest this sample.
    MachTransition,
    /// Highest point above the line of sight within the target distance.
    Apex,
}

impl TrajectoryFlag {
    /// Stable snake_case label for this flag (e.g. `"zero_crossing"`), without allocating.
    pub fn as_str(&self) -> &'static str {
        match self {
            TrajectoryFlag::ZeroCrossing => "zero_crossing",
            TrajectoryFlag::MachTransition => "mach_transition",
            TrajectoryFlag::Apex => "apex",
        }
    }

    /// Owned copy of [`as_str`](Self::as_str).
    // Kept as an inherent method for existing callers. A `Display` impl is deliberately not
    // added alongside it: clippy's deny-by-default `inherent_to_string_shadow_display` would fire.
    #[allow(clippy::inherent_to_string)]
    pub fn to_string(&self) -> String {
        self.as_str().to_owned()
    }
}
21
22/// Single trajectory sample point
23#[derive(Debug, Clone)]
24pub struct TrajectorySample {
25    pub distance_m: f64,
26    pub drop_m: f64,
27    pub wind_drift_m: f64,
28    pub velocity_mps: f64,
29    pub energy_j: f64,
30    pub time_s: f64,
31    pub flags: Vec<TrajectoryFlag>,
32}
33
34/// Trajectory solution data for sampling
35#[derive(Debug, Clone)]
36pub struct TrajectoryData {
37    pub times: Vec<f64>,
38    pub positions: Vec<Vector3<f64>>,  // [x, y, z] positions
39    pub velocities: Vec<Vector3<f64>>, // [vx, vy, vz] velocities
40    pub transonic_distances: Vec<f64>, // Distances where mach transitions occur
41}
42
/// Summary outputs of a trajectory solution consumed by the sampling step.
#[derive(Clone, Debug)]
pub struct TrajectoryOutputs {
    /// Horizontal distance to the target (m); used as the sampling limit.
    pub target_distance_horiz_m: f64,
    /// Vertical height of the target (m); the far end of the line of sight.
    pub target_vertical_height_m: f64,
    /// Total time of flight (s). Not read by the sampling code in this file.
    pub time_of_flight_s: f64,
    /// Presumably the horizontal distance of the maximum ordinate (m). Not read by the
    /// sampling code in this file.
    pub max_ord_dist_horiz_m: f64,
    /// Height of the sight above the bore (m), i.e. the near end of the line of sight.
    /// On a flat shot the line of sight is horizontal at `y = sight_height_m`.
    pub sight_height_m: f64,
}
54
55/// Sample trajectory at regular distance intervals with vectorized operations
56pub fn sample_trajectory(
57    trajectory_data: &TrajectoryData,
58    outputs: &TrajectoryOutputs,
59    step_m: f64,
60    mass_kg: f64,
61) -> Vec<TrajectorySample> {
62    let step_size = if step_m <= 0.0 {
63        return Vec::new();
64    } else if step_m < 0.1 {
65        0.1
66    } else {
67        step_m
68    };
69
70    // Use the input target distance as the limit for sampling
71    let max_dist = outputs.target_distance_horiz_m;
72    if max_dist < 1e-9 {
73        return Vec::new();
74    }
75
76    // Extract trajectory arrays for vectorized operations (McCoy: X=downrange, Z=lateral)
77    let downrange_vals: Vec<f64> = trajectory_data.positions.iter().map(|p| p.x).collect();
78    let y_vals: Vec<f64> = trajectory_data.positions.iter().map(|p| p.y).collect();
79    let lateral_vals: Vec<f64> = trajectory_data.positions.iter().map(|p| p.z).collect();
80
81    // Calculate speeds and energies
82    let speeds: Vec<f64> = trajectory_data
83        .velocities
84        .iter()
85        .map(|v| v.norm())
86        .collect();
87    let energies: Vec<f64> = speeds
88        .iter()
89        .map(|&speed| 0.5 * mass_kg * speed * speed)
90        .collect();
91
92    // Generate sampling distances
93    // Calculate number of steps to reach target without exceeding it
94    let num_steps = (max_dist / step_size).ceil() as usize + 1;
95    let distances: Vec<f64> = (0..num_steps)
96        .map(|i| i as f64 * step_size)
97        .filter(|&d| d <= max_dist + 0.1) // Stop exactly at target (with tiny tolerance for rounding)
98        .collect();
99
100    // Vectorized interpolation for all trajectory data
101    let mut samples = Vec::with_capacity(distances.len());
102
103    for &distance in &distances {
104        // Interpolate using X (downrange) as the independent variable
105        // McCoy coordinate system: x=downrange, y=vertical, z=lateral (wind drift)
106        let y_interp = interpolate(&downrange_vals, &y_vals, distance); // vertical at downrange distance
107        let wind_drift = interpolate(&downrange_vals, &lateral_vals, distance); // lateral drift at downrange distance
108        let velocity = interpolate(&downrange_vals, &speeds, distance); // velocity at downrange distance
109        let time = interpolate(&downrange_vals, &trajectory_data.times, distance); // time at downrange distance
110        let energy = interpolate(&downrange_vals, &energies, distance); // energy at downrange distance
111
112        // Calculate line-of-sight y-coordinate and drop
113        // The LOS is a straight line from the SIGHT to the target
114        // The sight is at y = sight_height_m above the bore (which starts at y = 0)
115        // For a flat shot: LOS is horizontal at y = sight_height_m
116        // For elevated/depressed shots: LOS slopes from sight_height_m to target_vertical_height_m
117        //
118        // Drop convention:
119        // - Positive drop means bullet is below LOS (has dropped)
120        // - Negative drop means bullet is above LOS (has risen)
121        // Therefore: drop = LOS - actual (not actual - LOS)
122        //
123        // LOS interpolation: starts at sight_height_m (z=0), ends at target_vertical_height_m (z=max_dist)
124        // Note: For a properly zeroed flat shot, target_vertical_height_m should equal sight_height_m
125        // (bullet ends at LOS at target distance for a point-blank shot)
126        let los_y = outputs.sight_height_m
127            + (outputs.target_vertical_height_m - outputs.sight_height_m) * distance / max_dist;
128        let drop = los_y - y_interp; // LOS - actual: positive when bullet is below LOS
129
130        samples.push(TrajectorySample {
131            distance_m: distance,
132            drop_m: drop,
133            wind_drift_m: wind_drift,
134            velocity_mps: velocity,
135            energy_j: energy,
136            time_s: time,
137            flags: Vec::new(), // Flags will be added later
138        });
139    }
140
141    // Add flags using vectorized detection
142    add_trajectory_flags(&mut samples, &trajectory_data.transonic_distances, max_dist);
143
144    samples
145}
146
/// Piecewise-linear lookup of `y` at `x` over the table (`x_vals`, `y_vals`).
///
/// `x_vals` is bisected, so it is expected to be sorted ascending. Queries at or beyond
/// either end return that end's `y` value. Returns `0.0` when `x_vals` is empty or the two
/// tables differ in length. Adjacent knots closer than `f64::EPSILON` return the left value.
fn interpolate(x_vals: &[f64], y_vals: &[f64], x: f64) -> f64 {
    let n = x_vals.len();
    if n == 0 || n != y_vals.len() {
        return 0.0;
    }

    // Clamp to the end values outside the tabulated range.
    if x <= x_vals[0] {
        return y_vals[0];
    }
    if x >= x_vals[n - 1] {
        return y_vals[n - 1];
    }

    // Bisect until `lo` and `hi` are adjacent, bracketing `x` for sorted input.
    let (mut lo, mut hi) = (0_usize, n - 1);
    while hi - lo > 1 {
        let mid = (lo + hi) / 2;
        if x_vals[mid] <= x {
            lo = mid;
        } else {
            hi = mid;
        }
    }

    let (x_lo, x_hi) = (x_vals[lo], x_vals[hi]);
    let (y_lo, y_hi) = (y_vals[lo], y_vals[hi]);
    let span = x_hi - x_lo;
    if span.abs() < f64::EPSILON {
        return y_lo;
    }
    y_lo + (y_hi - y_lo) * (x - x_lo) / span
}
190
191/// Add trajectory flags using vectorized detection algorithms
192fn add_trajectory_flags(
193    samples: &mut [TrajectorySample],
194    transonic_distances: &[f64],
195    target_distance_input_m: f64,
196) {
197    let tolerance = 1e-6;
198
199    // 1. Zero crossings - vectorized detection
200    detect_zero_crossings(samples, tolerance);
201
202    // 2. Mach transitions
203    for &transonic_dist in transonic_distances {
204        if let Some(idx) = find_closest_sample_index(samples, transonic_dist) {
205            samples[idx].flags.push(TrajectoryFlag::MachTransition);
206        }
207    }
208
209    // 3. Apex - find the point with maximum height between muzzle and target
210    // Since drop is positive when bullet is below LOS and negative when above,
211    // the apex is where drop is minimum (most negative)
212    if samples.len() > 2 {
213        // Use the target distance passed as parameter
214        let target_distance_m = target_distance_input_m;
215
216        // Find the index of maximum height (minimum drop, most negative) within target distance.
217        // Only mark an interior apex if it is actually above the muzzle/first sample.
218        let first_drop = samples[0].drop_m;
219        let mut min_drop = first_drop;
220        let mut apex_idx: Option<usize> = None;
221
222        // Search from index 1, but stop at target distance
223        for i in 1..samples.len() {
224            // Only consider points up to target distance
225            if samples[i].distance_m > target_distance_m {
226                break;
227            }
228
229            if samples[i].drop_m < min_drop {
230                min_drop = samples[i].drop_m;
231                apex_idx = Some(i);
232            }
233        }
234
235        if let Some(idx) = apex_idx {
236            samples[idx].flags.push(TrajectoryFlag::Apex);
237        }
238    }
239}
240
241/// Detect zero crossings in trajectory drop values using vectorized operations
242fn detect_zero_crossings(samples: &mut [TrajectorySample], tolerance: f64) {
243    if samples.len() < 2 {
244        return;
245    }
246
247    let drops: Vec<f64> = samples.iter().map(|s| s.drop_m).collect();
248
249    // Find crossing indices where drop changes sign
250    for i in 0..(drops.len() - 1) {
251        let current = drops[i];
252        let next = drops[i + 1];
253
254        // Check for sign change crossings
255        let crosses_zero = (current < -tolerance && next >= -tolerance)
256            || (current > tolerance && next <= tolerance);
257
258        if crosses_zero {
259            samples[i + 1].flags.push(TrajectoryFlag::ZeroCrossing);
260        }
261    }
262
263    // Find points very close to zero
264    for (i, &drop) in drops.iter().enumerate() {
265        if drop.abs() <= tolerance {
266            samples[i].flags.push(TrajectoryFlag::ZeroCrossing);
267        }
268    }
269
270    // Remove duplicate zero crossing flags
271    for sample in samples.iter_mut() {
272        let mut unique_flags = Vec::new();
273        let mut seen = HashSet::new();
274
275        for flag in &sample.flags {
276            if seen.insert(flag.clone()) {
277                unique_flags.push(flag.clone());
278            }
279        }
280        sample.flags = unique_flags;
281    }
282}
283
284/// Find the closest sample index to a given distance
285fn find_closest_sample_index(samples: &[TrajectorySample], target_distance: f64) -> Option<usize> {
286    if samples.is_empty() {
287        return None;
288    }
289
290    // Binary search for the closest distance
291    let distances: Vec<f64> = samples.iter().map(|s| s.distance_m).collect();
292
293    let mut left = 0;
294    let mut right = distances.len();
295
296    while left < right {
297        let mid = (left + right) / 2;
298        if distances[mid] < target_distance {
299            left = mid + 1;
300        } else {
301            right = mid;
302        }
303    }
304
305    // Find the closest point (could be left-1 or left)
306    let mut best_idx = left.min(distances.len() - 1);
307
308    if left > 0 {
309        let left_dist = (distances[left - 1] - target_distance).abs();
310        let right_dist = (distances[best_idx] - target_distance).abs();
311
312        // Prefer earlier index in case of tie
313        if left_dist <= right_dist {
314            best_idx = left - 1;
315        }
316    }
317
318    Some(best_idx)
319}
320
321/// Convert trajectory samples to Python-compatible format
322pub fn trajectory_samples_to_dicts(samples: &[TrajectorySample]) -> Vec<TrajectoryDict> {
323    samples
324        .iter()
325        .map(|sample| TrajectoryDict {
326            distance_m: sample.distance_m,
327            drop_m: sample.drop_m,
328            wind_drift_m: sample.wind_drift_m,
329            velocity_mps: sample.velocity_mps,
330            energy_j: sample.energy_j,
331            time_s: sample.time_s,
332            flags: sample.flags.iter().map(|f| f.to_string()).collect(),
333        })
334        .collect()
335}
336
/// Flattened, Python-compatible copy of a trajectory sample, with flags as string labels.
#[derive(Clone, Debug)]
pub struct TrajectoryDict {
    /// Downrange distance (m).
    pub distance_m: f64,
    /// Distance below the line of sight (m); negative when above it.
    pub drop_m: f64,
    /// Lateral displacement (m).
    pub wind_drift_m: f64,
    /// Speed (m/s).
    pub velocity_mps: f64,
    /// Kinetic energy (J).
    pub energy_j: f64,
    /// Time of flight (s).
    pub time_s: f64,
    /// Flag labels such as `"zero_crossing"`, `"mach_transition"` or `"apex"`.
    pub flags: Vec<String>,
}
348
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal sample for flag/lookup tests; only `distance_m` and `drop_m` matter there.
    fn sample_at(distance_m: f64, drop_m: f64) -> TrajectorySample {
        TrajectorySample {
            distance_m,
            drop_m,
            wind_drift_m: 0.0,
            velocity_mps: 100.0,
            energy_j: 1000.0,
            time_s: 0.0,
            flags: Vec::new(),
        }
    }

    /// Three-point trajectory peaking at 100 m downrange, with a Mach transition at 150 m.
    ///
    /// McCoy axes: x = downrange, y = vertical, z = lateral (wind drift). Velocities use the
    /// same axis order, so their dominant component is downrange.
    fn three_point_trajectory() -> TrajectoryData {
        TrajectoryData {
            times: vec![0.0, 1.0, 2.0],
            positions: vec![
                Vector3::new(0.0, 0.0, 0.0),    // muzzle
                Vector3::new(100.0, 10.0, 1.0), // apex height, 1 m drift
                Vector3::new(200.0, 5.0, 2.0),  // below apex, 2 m drift
            ],
            velocities: vec![
                Vector3::new(100.0, 10.0, 1.0),
                Vector3::new(95.0, 5.0, 1.0),
                Vector3::new(90.0, 0.0, 1.0),
            ],
            transonic_distances: vec![150.0],
        }
    }

    /// Bore-referenced outputs (sight height and target height 0) for the given target distance.
    fn outputs_for(target_distance_horiz_m: f64) -> TrajectoryOutputs {
        TrajectoryOutputs {
            target_distance_horiz_m,
            target_vertical_height_m: 0.0,
            time_of_flight_s: 2.0,
            max_ord_dist_horiz_m: 100.0,
            sight_height_m: 0.0,
        }
    }

    #[test]
    fn test_flag_labels() {
        assert_eq!(TrajectoryFlag::ZeroCrossing.to_string(), "zero_crossing");
        assert_eq!(TrajectoryFlag::MachTransition.to_string(), "mach_transition");
        assert_eq!(TrajectoryFlag::Apex.to_string(), "apex");
    }

    #[test]
    fn test_interpolate() {
        let x_vals = vec![0.0, 1.0, 2.0, 3.0];
        let y_vals = vec![0.0, 10.0, 20.0, 30.0];

        assert_eq!(interpolate(&x_vals, &y_vals, 0.5), 5.0);
        assert_eq!(interpolate(&x_vals, &y_vals, 1.5), 15.0);
        assert_eq!(interpolate(&x_vals, &y_vals, 2.5), 25.0);

        // Out-of-range queries clamp to the end values
        assert_eq!(interpolate(&x_vals, &y_vals, -1.0), 0.0);
        assert_eq!(interpolate(&x_vals, &y_vals, 4.0), 30.0);

        // Degenerate tables
        assert_eq!(interpolate(&[], &[], 1.0), 0.0);
        assert_eq!(interpolate(&x_vals, &y_vals[..2], 1.5), 0.0);
        assert_eq!(interpolate(&[5.0], &[7.0], 1.0), 7.0);
    }

    #[test]
    fn test_find_closest_sample_index() {
        let samples = vec![sample_at(0.0, 0.0), sample_at(10.0, -1.0), sample_at(20.0, -4.0)];

        assert_eq!(find_closest_sample_index(&samples, 5.0), Some(0)); // tie -> earlier
        assert_eq!(find_closest_sample_index(&samples, 12.0), Some(1));
        assert_eq!(find_closest_sample_index(&samples, 18.0), Some(2));
        assert_eq!(find_closest_sample_index(&samples, -3.0), Some(0));
        assert_eq!(find_closest_sample_index(&samples, 99.0), Some(2));
        assert_eq!(find_closest_sample_index(&[], 5.0), None);
    }

    #[test]
    fn test_detect_zero_crossings() {
        let mut samples = vec![
            sample_at(0.0, 1.0),   // positive drop
            sample_at(10.0, -0.5), // negative: crossing happened in between
            sample_at(20.0, -2.0), // still negative
        ];

        detect_zero_crossings(&mut samples, 1e-6);

        // Only the sample after the sign change is flagged
        assert!(!samples[0].flags.contains(&TrajectoryFlag::ZeroCrossing));
        assert!(samples[1].flags.contains(&TrajectoryFlag::ZeroCrossing));
        assert!(!samples[2].flags.contains(&TrajectoryFlag::ZeroCrossing));
    }

    #[test]
    fn test_zero_crossing_flag_not_duplicated() {
        // Sample 1 is both a sign change and within tolerance of zero.
        let mut samples = vec![sample_at(0.0, 1.0), sample_at(10.0, 0.0), sample_at(20.0, -1.0)];

        detect_zero_crossings(&mut samples, 1e-6);

        let count = samples[1]
            .flags
            .iter()
            .filter(|f| **f == TrajectoryFlag::ZeroCrossing)
            .count();
        assert_eq!(count, 1);
    }

    #[test]
    fn test_sample_trajectory_basic() {
        let samples =
            sample_trajectory(&three_point_trajectory(), &outputs_for(200.0), 50.0, 0.1);

        // Samples at 0, 50, 100, 150, 200 m
        let distances: Vec<f64> = samples.iter().map(|s| s.distance_m).collect();
        assert_eq!(distances, vec![0.0, 50.0, 100.0, 150.0, 200.0]);

        // Interpolated speed lies between the bracketing knot speeds
        assert!(samples[1].velocity_mps > 90.0 && samples[1].velocity_mps < 100.0);

        assert!(samples[2].flags.contains(&TrajectoryFlag::Apex)); // at apex distance
        assert!(samples[3].flags.contains(&TrajectoryFlag::MachTransition)); // at transonic distance
    }

    #[test]
    fn test_sample_trajectory_rejects_degenerate_inputs() {
        let data = three_point_trajectory();
        for &step in &[0.0, -1.0, f64::NAN] {
            assert!(sample_trajectory(&data, &outputs_for(200.0), step, 0.1).is_empty());
        }
        assert!(sample_trajectory(&data, &outputs_for(0.0), 50.0, 0.1).is_empty());
    }

    #[test]
    fn test_sample_trajectory_clamps_small_steps() {
        // Steps below 0.1 m are raised to 0.1 m: 0.0, 0.1, ..., 1.0 -> 11 samples.
        let samples =
            sample_trajectory(&three_point_trajectory(), &outputs_for(1.0), 0.01, 0.1);
        assert_eq!(samples.len(), 11);
    }
}
484        // Check flags
485        assert!(samples[2].flags.contains(&TrajectoryFlag::Apex)); // At apex distance
486        assert!(samples[3].flags.contains(&TrajectoryFlag::MachTransition)); // At transonic distance
487    }
488}