ballistics_engine/
trajectory_sampling.rs

1use nalgebra::Vector3;
2use std::collections::HashSet;
3
/// Notable events that can be attached to a [`TrajectorySample`].
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq, Hash)]
pub enum TrajectoryFlag {
    /// The drop relative to the line of sight changes sign, or is within tolerance of zero.
    ZeroCrossing,
    /// A Mach transition distance lies nearest to this sample.
    MachTransition,
    /// Highest point relative to the line of sight (most negative drop).
    Apex,
}
11
12impl TrajectoryFlag {
13    pub fn to_string(&self) -> String {
14        match self {
15            TrajectoryFlag::ZeroCrossing => "zero_crossing".to_string(),
16            TrajectoryFlag::MachTransition => "mach_transition".to_string(),
17            TrajectoryFlag::Apex => "apex".to_string(),
18        }
19    }
20}
21
22/// Single trajectory sample point
23#[derive(Debug, Clone)]
24pub struct TrajectorySample {
25    pub distance_m: f64,
26    pub drop_m: f64,
27    pub wind_drift_m: f64,
28    pub velocity_mps: f64,
29    pub energy_j: f64,
30    pub time_s: f64,
31    pub flags: Vec<TrajectoryFlag>,
32}
33
34/// Trajectory solution data for sampling
35#[derive(Debug, Clone)]
36pub struct TrajectoryData {
37    pub times: Vec<f64>,
38    pub positions: Vec<Vector3<f64>>,  // [x, y, z] positions
39    pub velocities: Vec<Vector3<f64>>, // [vx, vy, vz] velocities
40    pub transonic_distances: Vec<f64>, // Distances where mach transitions occur
41}
42
/// Summary values from the solver that drive trajectory sampling.
#[derive(Debug, Clone)]
pub struct TrajectoryOutputs {
    /// Horizontal distance to the target (meters); sampling stops here.
    pub target_distance_horiz_m: f64,
    /// Height of the target (meters); the line of sight ends at this height at the
    /// target distance.
    pub target_vertical_height_m: f64,
    /// Time of flight in seconds (not read by `sample_trajectory`).
    pub time_of_flight_s: f64,
    /// Horizontal distance of the maximum ordinate in meters (not read by `sample_trajectory`).
    pub max_ord_dist_horiz_m: f64,
    /// Height of the sight above the bore (meters); the line of sight starts here.
    /// For a flat shot the line of sight is horizontal at `y = sight_height_m`.
    pub sight_height_m: f64,
}
54
55/// Sample trajectory at regular distance intervals with vectorized operations
56pub fn sample_trajectory(
57    trajectory_data: &TrajectoryData,
58    outputs: &TrajectoryOutputs,
59    step_m: f64,
60    mass_kg: f64,
61) -> Vec<TrajectorySample> {
62    let step_size = if step_m <= 0.0 {
63        return Vec::new();
64    } else if step_m < 0.1 {
65        0.1
66    } else {
67        step_m
68    };
69
70    // Use the input target distance as the limit for sampling
71    let max_dist = outputs.target_distance_horiz_m;
72    if max_dist < 1e-9 {
73        return Vec::new();
74    }
75
76    // Extract trajectory arrays for vectorized operations
77    let x_vals: Vec<f64> = trajectory_data.positions.iter().map(|p| p.x).collect();
78    let y_vals: Vec<f64> = trajectory_data.positions.iter().map(|p| p.y).collect();
79    let z_vals: Vec<f64> = trajectory_data.positions.iter().map(|p| p.z).collect();
80
81    // Calculate speeds and energies
82    let speeds: Vec<f64> = trajectory_data
83        .velocities
84        .iter()
85        .map(|v| v.norm())
86        .collect();
87    let energies: Vec<f64> = speeds
88        .iter()
89        .map(|&speed| 0.5 * mass_kg * speed * speed)
90        .collect();
91
92    // Generate sampling distances
93    // Calculate number of steps to reach target without exceeding it
94    let num_steps = (max_dist / step_size).ceil() as usize + 1;
95    let distances: Vec<f64> = (0..num_steps)
96        .map(|i| i as f64 * step_size)
97        .filter(|&d| d <= max_dist + 0.1) // Stop exactly at target (with tiny tolerance for rounding)
98        .collect();
99
100    // Vectorized interpolation for all trajectory data
101    let mut samples = Vec::with_capacity(distances.len());
102
103    for &distance in &distances {
104        // Interpolate using z (downrange) as the independent variable
105        // Coordinate system: x=lateral (wind drift), y=vertical, z=downrange
106        let y_interp = interpolate(&z_vals, &y_vals, distance); // vertical at downrange distance
107        let wind_drift = interpolate(&z_vals, &x_vals, distance); // lateral drift at downrange distance
108        let velocity = interpolate(&z_vals, &speeds, distance); // velocity at downrange distance
109        let time = interpolate(&z_vals, &trajectory_data.times, distance); // time at downrange distance
110        let energy = interpolate(&z_vals, &energies, distance); // energy at downrange distance
111
112        // Calculate line-of-sight y-coordinate and drop
113        // The LOS is a straight line from the SIGHT to the target
114        // The sight is at y = sight_height_m above the bore (which starts at y = 0)
115        // For a flat shot: LOS is horizontal at y = sight_height_m
116        // For elevated/depressed shots: LOS slopes from sight_height_m to target_vertical_height_m
117        //
118        // Drop convention:
119        // - Positive drop means bullet is below LOS (has dropped)
120        // - Negative drop means bullet is above LOS (has risen)
121        // Therefore: drop = LOS - actual (not actual - LOS)
122        //
123        // LOS interpolation: starts at sight_height_m (z=0), ends at target_vertical_height_m (z=max_dist)
124        // Note: For a properly zeroed flat shot, target_vertical_height_m should equal sight_height_m
125        // (bullet ends at LOS at target distance for a point-blank shot)
126        let los_y = outputs.sight_height_m
127            + (outputs.target_vertical_height_m - outputs.sight_height_m) * distance / max_dist;
128        let drop = los_y - y_interp; // LOS - actual: positive when bullet is below LOS
129
130        samples.push(TrajectorySample {
131            distance_m: distance,
132            drop_m: drop,
133            wind_drift_m: wind_drift,
134            velocity_mps: velocity,
135            energy_j: energy,
136            time_s: time,
137            flags: Vec::new(), // Flags will be added later
138        });
139    }
140
141    // Add flags using vectorized detection
142    add_trajectory_flags(&mut samples, &trajectory_data.transonic_distances, max_dist);
143
144    samples
145}
146
/// Piecewise-linear interpolation of `y_vals` at `x`, with `x_vals` assumed ascending.
///
/// * Returns `0.0` if either table is empty or the tables differ in length.
/// * Outside the tabulated range the first/last `y` value is held (no extrapolation).
/// * If the bracketing abscissae are (almost) equal, the left `y` value is returned.
fn interpolate(x_vals: &[f64], y_vals: &[f64], x: f64) -> f64 {
    // Nothing to interpolate, or the tables disagree.
    if x_vals.is_empty() || y_vals.is_empty() || x_vals.len() != y_vals.len() {
        return 0.0;
    }

    let last = x_vals.len() - 1;

    // Hold the end values outside the tabulated range.
    if x <= x_vals[0] {
        return y_vals[0];
    }
    if x >= x_vals[last] {
        return y_vals[last];
    }

    // Bisect down to the bracketing pair x_vals[lo] <= x < x_vals[hi] with hi == lo + 1.
    let (mut lo, mut hi) = (0, last);
    while hi - lo > 1 {
        let mid = (lo + hi) / 2;
        if x_vals[mid] <= x {
            lo = mid;
        } else {
            hi = mid;
        }
    }

    let (x_lo, x_hi) = (x_vals[lo], x_vals[hi]);
    let (y_lo, y_hi) = (y_vals[lo], y_vals[hi]);
    let span = x_hi - x_lo;

    if span.abs() < f64::EPSILON {
        y_lo
    } else {
        y_lo + (y_hi - y_lo) * (x - x_lo) / span
    }
}
190
191/// Add trajectory flags using vectorized detection algorithms
192fn add_trajectory_flags(
193    samples: &mut [TrajectorySample],
194    transonic_distances: &[f64],
195    target_distance_input_m: f64,
196) {
197    let tolerance = 1e-6;
198
199    // 1. Zero crossings - vectorized detection
200    detect_zero_crossings(samples, tolerance);
201
202    // 2. Mach transitions
203    for &transonic_dist in transonic_distances {
204        if let Some(idx) = find_closest_sample_index(samples, transonic_dist) {
205            samples[idx].flags.push(TrajectoryFlag::MachTransition);
206        }
207    }
208
209    // 3. Apex - find the point with maximum height between muzzle and target
210    // Since drop is positive when bullet is below LOS and negative when above,
211    // the apex is where drop is minimum (most negative)
212    if samples.len() > 2 {
213        // Use the target distance passed as parameter
214        let target_distance_m = target_distance_input_m;
215
216        // Find the index of maximum height (minimum drop, most negative) within target distance
217        // Exclude first point (always 0 for auto-zeroing)
218        let mut min_drop = f64::INFINITY;
219        let mut apex_idx = 1;
220
221        // Search from index 1, but stop at target distance
222        for i in 1..samples.len() {
223            // Only consider points up to target distance
224            if samples[i].distance_m > target_distance_m {
225                break;
226            }
227
228            if samples[i].drop_m < min_drop {
229                min_drop = samples[i].drop_m;
230                apex_idx = i;
231            }
232        }
233
234        // Mark the apex
235        samples[apex_idx].flags.push(TrajectoryFlag::Apex);
236    }
237}
238
239/// Detect zero crossings in trajectory drop values using vectorized operations
240fn detect_zero_crossings(samples: &mut [TrajectorySample], tolerance: f64) {
241    if samples.len() < 2 {
242        return;
243    }
244
245    let drops: Vec<f64> = samples.iter().map(|s| s.drop_m).collect();
246
247    // Find crossing indices where drop changes sign
248    for i in 0..(drops.len() - 1) {
249        let current = drops[i];
250        let next = drops[i + 1];
251
252        // Check for sign change crossings
253        let crosses_zero = (current < -tolerance && next >= -tolerance)
254            || (current > tolerance && next <= tolerance);
255
256        if crosses_zero {
257            samples[i + 1].flags.push(TrajectoryFlag::ZeroCrossing);
258        }
259    }
260
261    // Find points very close to zero
262    for (i, &drop) in drops.iter().enumerate() {
263        if drop.abs() <= tolerance {
264            samples[i].flags.push(TrajectoryFlag::ZeroCrossing);
265        }
266    }
267
268    // Remove duplicate zero crossing flags
269    for sample in samples.iter_mut() {
270        let mut unique_flags = Vec::new();
271        let mut seen = HashSet::new();
272
273        for flag in &sample.flags {
274            if seen.insert(flag.clone()) {
275                unique_flags.push(flag.clone());
276            }
277        }
278        sample.flags = unique_flags;
279    }
280}
281
282/// Find the closest sample index to a given distance
283fn find_closest_sample_index(samples: &[TrajectorySample], target_distance: f64) -> Option<usize> {
284    if samples.is_empty() {
285        return None;
286    }
287
288    // Binary search for the closest distance
289    let distances: Vec<f64> = samples.iter().map(|s| s.distance_m).collect();
290
291    let mut left = 0;
292    let mut right = distances.len();
293
294    while left < right {
295        let mid = (left + right) / 2;
296        if distances[mid] < target_distance {
297            left = mid + 1;
298        } else {
299            right = mid;
300        }
301    }
302
303    // Find the closest point (could be left-1 or left)
304    let mut best_idx = left.min(distances.len() - 1);
305
306    if left > 0 {
307        let left_dist = (distances[left - 1] - target_distance).abs();
308        let right_dist = (distances[best_idx] - target_distance).abs();
309
310        // Prefer earlier index in case of tie
311        if left_dist <= right_dist {
312            best_idx = left - 1;
313        }
314    }
315
316    Some(best_idx)
317}
318
319/// Convert trajectory samples to Python-compatible format
320pub fn trajectory_samples_to_dicts(samples: &[TrajectorySample]) -> Vec<TrajectoryDict> {
321    samples
322        .iter()
323        .map(|sample| TrajectoryDict {
324            distance_m: sample.distance_m,
325            drop_m: sample.drop_m,
326            wind_drift_m: sample.wind_drift_m,
327            velocity_mps: sample.velocity_mps,
328            energy_j: sample.energy_j,
329            time_s: sample.time_s,
330            flags: sample.flags.iter().map(|f| f.to_string()).collect(),
331        })
332        .collect()
333}
334
/// Python-compatible trajectory sample: the same values as a `TrajectorySample`, with the
/// event flags rendered as plain strings.
#[derive(Debug)]
#[derive(Clone)]
pub struct TrajectoryDict {
    /// Downrange distance (meters).
    pub distance_m: f64,
    /// Drop relative to the line of sight (meters).
    pub drop_m: f64,
    /// Lateral (wind) drift (meters).
    pub wind_drift_m: f64,
    /// Speed (m/s).
    pub velocity_mps: f64,
    /// Kinetic energy (joules).
    pub energy_j: f64,
    /// Time of flight (seconds).
    pub time_s: f64,
    /// Event flag names.
    pub flags: Vec<String>,
}
346
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a flag-free sample with the given values.
    fn sample_at(
        distance_m: f64,
        drop_m: f64,
        wind_drift_m: f64,
        velocity_mps: f64,
        energy_j: f64,
        time_s: f64,
    ) -> TrajectorySample {
        TrajectorySample {
            distance_m,
            drop_m,
            wind_drift_m,
            velocity_mps,
            energy_j,
            time_s,
            flags: Vec::new(),
        }
    }

    #[test]
    fn test_interpolate() {
        let xs = [0.0, 1.0, 2.0, 3.0];
        let ys = [0.0, 10.0, 20.0, 30.0];

        // Interior points land exactly on the straight line.
        for &(x, expected) in &[(0.5, 5.0), (1.5, 15.0), (2.5, 25.0)] {
            assert_eq!(interpolate(&xs, &ys, x), expected);
        }

        // Outside the table the end values are held.
        assert_eq!(interpolate(&xs, &ys, -1.0), 0.0);
        assert_eq!(interpolate(&xs, &ys, 4.0), 30.0);
    }

    #[test]
    fn test_find_closest_sample_index() {
        let samples = [
            sample_at(0.0, 0.0, 0.0, 100.0, 1000.0, 0.0),
            sample_at(10.0, -1.0, 0.1, 95.0, 950.0, 0.1),
            sample_at(20.0, -4.0, 0.2, 90.0, 900.0, 0.2),
        ];

        assert_eq!(find_closest_sample_index(&samples, 5.0), Some(0)); // tie -> earlier
        assert_eq!(find_closest_sample_index(&samples, 12.0), Some(1));
        assert_eq!(find_closest_sample_index(&samples, 18.0), Some(2));
    }

    #[test]
    fn test_detect_zero_crossings() {
        let mut samples = [
            sample_at(0.0, 1.0, 0.0, 100.0, 1000.0, 0.0), // positive drop
            sample_at(10.0, -0.5, 0.1, 95.0, 950.0, 0.1), // sign flips here
            sample_at(20.0, -2.0, 0.2, 90.0, 900.0, 0.2), // stays negative
        ];

        detect_zero_crossings(&mut samples, 1e-6);

        let flagged: Vec<bool> = samples
            .iter()
            .map(|s| s.flags.contains(&TrajectoryFlag::ZeroCrossing))
            .collect();
        assert_eq!(flagged, vec![false, true, false]);
    }

    #[test]
    fn test_sample_trajectory_basic() {
        // Coordinate system: x = lateral (wind drift), y = vertical, z = downrange.
        let trajectory_data = TrajectoryData {
            times: vec![0.0, 1.0, 2.0],
            positions: vec![
                Vector3::new(0.0, 0.0, 0.0),    // muzzle
                Vector3::new(1.0, 10.0, 100.0), // highest point, mid-range
                Vector3::new(2.0, 5.0, 200.0),  // end, below the high point
            ],
            velocities: vec![
                Vector3::new(1.0, 10.0, 100.0),
                Vector3::new(1.0, 5.0, 95.0),
                Vector3::new(1.0, 0.0, 90.0),
            ],
            transonic_distances: vec![150.0],
        };

        let outputs = TrajectoryOutputs {
            target_distance_horiz_m: 200.0,
            target_vertical_height_m: 0.0,
            time_of_flight_s: 2.0,
            max_ord_dist_horiz_m: 100.0,
            sight_height_m: 0.0, // bore-referenced coordinates for this test
        };

        let samples = sample_trajectory(&trajectory_data, &outputs, 50.0, 0.1);

        // One sample every 50 m from the muzzle to the 200 m target.
        let distances: Vec<f64> = samples.iter().map(|s| s.distance_m).collect();
        assert_eq!(distances, vec![0.0, 50.0, 100.0, 150.0, 200.0]);

        // Speed at 50 m is interpolated between the 0 m and 100 m solver points.
        let v50 = samples[1].velocity_mps;
        assert!(v50 > 90.0 && v50 < 100.0);

        // Flags: apex at the 100 m high point, Mach transition at 150 m.
        assert!(samples[2].flags.contains(&TrajectoryFlag::Apex));
        assert!(samples[3].flags.contains(&TrajectoryFlag::MachTransition));
    }
}