ballistics_engine/
trajectory_sampling.rs

1use nalgebra::Vector3;
2use std::collections::HashSet;
3
/// Notable events that can be attached to a trajectory sample.
///
/// The enum carries no data, so it is `Copy`: flags can be passed and stored by value
/// without explicit cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TrajectoryFlag {
    /// The drop relative to the line of sight reaches or passes through zero.
    ZeroCrossing,
    /// A Mach transition (transonic distance) occurs at or near this sample.
    MachTransition,
    /// The sample lying furthest above the line of sight.
    Apex,
}
11
12impl TrajectoryFlag {
13    pub fn to_string(&self) -> String {
14        match self {
15            TrajectoryFlag::ZeroCrossing => "zero_crossing".to_string(),
16            TrajectoryFlag::MachTransition => "mach_transition".to_string(),
17            TrajectoryFlag::Apex => "apex".to_string(),
18        }
19    }
20}
21
22/// Single trajectory sample point
23#[derive(Debug, Clone)]
24pub struct TrajectorySample {
25    pub distance_m: f64,
26    pub drop_m: f64,
27    pub wind_drift_m: f64,
28    pub velocity_mps: f64,
29    pub energy_j: f64,
30    pub time_s: f64,
31    pub flags: Vec<TrajectoryFlag>,
32}
33
34/// Trajectory solution data for sampling
35#[derive(Debug, Clone)]
36pub struct TrajectoryData {
37    pub times: Vec<f64>,
38    pub positions: Vec<Vector3<f64>>,  // [x, y, z] positions
39    pub velocities: Vec<Vector3<f64>>, // [vx, vy, vz] velocities
40    pub transonic_distances: Vec<f64>, // Distances where mach transitions occur
41}
42
/// Summary values from a trajectory solve that control how it is sampled.
///
/// `sample_trajectory` reads only `target_distance_horiz_m`, which is both the sampling
/// limit and the far end of the line of sight. It also reads
/// `target_vertical_height_m`, the height of that far end. The other fields are
/// carried along for callers.
#[derive(Debug, Clone, PartialEq)]
pub struct TrajectoryOutputs {
    /// Horizontal distance to the target, in metres.
    pub target_distance_horiz_m: f64,
    /// Height of the target, in metres, in the same vertical frame as trajectory `y`.
    pub target_vertical_height_m: f64,
    /// Time of flight in seconds (presumably to the target — confirm with the solver).
    pub time_of_flight_s: f64,
    /// Horizontal distance of the maximum ordinate, in metres (presumably; not read here).
    pub max_ord_dist_horiz_m: f64,
}
51
52/// Sample trajectory at regular distance intervals with vectorized operations
53pub fn sample_trajectory(
54    trajectory_data: &TrajectoryData,
55    outputs: &TrajectoryOutputs,
56    step_m: f64,
57    mass_kg: f64,
58) -> Vec<TrajectorySample> {
59    let step_size = if step_m <= 0.0 {
60        return Vec::new();
61    } else if step_m < 0.1 {
62        0.1
63    } else {
64        step_m
65    };
66    
67    // Use the input target distance as the limit for sampling
68    let max_dist = outputs.target_distance_horiz_m;
69    if max_dist < 1e-9 {
70        return Vec::new();
71    }
72    
73    // Extract trajectory arrays for vectorized operations
74    let x_vals: Vec<f64> = trajectory_data.positions.iter().map(|p| p.x).collect();
75    let y_vals: Vec<f64> = trajectory_data.positions.iter().map(|p| p.y).collect();
76    let z_vals: Vec<f64> = trajectory_data.positions.iter().map(|p| p.z).collect();
77    
78    // Calculate speeds and energies
79    let speeds: Vec<f64> = trajectory_data.velocities.iter()
80        .map(|v| v.norm())
81        .collect();
82    let energies: Vec<f64> = speeds.iter()
83        .map(|&speed| 0.5 * mass_kg * speed * speed)
84        .collect();
85    
86    // Generate sampling distances
87    // Calculate number of steps to reach target without exceeding it
88    let num_steps = (max_dist / step_size).ceil() as usize + 1;
89    let distances: Vec<f64> = (0..num_steps)
90        .map(|i| i as f64 * step_size)
91        .filter(|&d| d <= max_dist + 0.1)  // Stop exactly at target (with tiny tolerance for rounding)
92        .collect();
93    
94    // Vectorized interpolation for all trajectory data
95    let mut samples = Vec::with_capacity(distances.len());
96    
97    // Get initial height (muzzle height) for proper LOS calculation
98    let muzzle_y = if !y_vals.is_empty() { y_vals[0] } else { 0.0 };
99    
100    for &distance in &distances {
101        // Interpolate using x (downrange) as the independent variable - STANDARD BALLISTICS CONVENTION
102        let y_interp = interpolate(&x_vals, &y_vals, distance);  // vertical at downrange distance
103        let wind_drift = interpolate(&x_vals, &z_vals, distance);  // lateral drift at downrange distance
104        let velocity = interpolate(&x_vals, &speeds, distance);  // velocity at downrange distance
105        let time = interpolate(&x_vals, &trajectory_data.times, distance);  // time at downrange distance
106        let energy = interpolate(&x_vals, &energies, distance);  // energy at downrange distance
107        
108        // Calculate line-of-sight y-coordinate and drop
109        // The LOS is the straight line from initial position to target
110        // For coordinate shots: goes from muzzle_y to target_vertical_height_m
111        // Drop convention:
112        // - Positive drop means bullet is below LOS (has dropped)
113        // - Negative drop means bullet is above LOS (has risen)
114        // Therefore: drop = LOS - actual (not actual - LOS)
115        let los_y = muzzle_y + (outputs.target_vertical_height_m - muzzle_y) * distance / max_dist;
116        let drop = los_y - y_interp;  // LOS - actual: positive when bullet is below LOS
117        
118        samples.push(TrajectorySample {
119            distance_m: distance,
120            drop_m: drop,
121            wind_drift_m: wind_drift,
122            velocity_mps: velocity,
123            energy_j: energy,
124            time_s: time,
125            flags: Vec::new(), // Flags will be added later
126        });
127    }
128    
129    // Add flags using vectorized detection
130    add_trajectory_flags(&mut samples, &trajectory_data.transonic_distances, max_dist);
131    
132    samples
133}
134
/// Piecewise-linear interpolation of `y_vals` over `x_vals`, evaluated at `x`.
///
/// `x_vals` is expected in ascending order and paired element-wise with `y_vals`.
/// Outside the tabulated range the nearest end value is returned. Empty or
/// length-mismatched inputs yield `0.0`. If the bracketing abscissae are closer than
/// `f64::EPSILON`, the left value is returned instead of dividing by ~0.
fn interpolate(x_vals: &[f64], y_vals: &[f64], x: f64) -> f64 {
    let n = x_vals.len();
    if n == 0 || n != y_vals.len() {
        return 0.0;
    }

    // Hold the end values outside the table.
    if x <= x_vals[0] {
        return y_vals[0];
    }
    if x >= x_vals[n - 1] {
        return y_vals[n - 1];
    }

    // Bisect down to adjacent indices; for ascending data x_vals[lo] <= x < x_vals[hi].
    let (mut lo, mut hi) = (0usize, n - 1);
    while lo + 1 < hi {
        let mid = (lo + hi) / 2;
        if x_vals[mid] <= x {
            lo = mid;
        } else {
            hi = mid;
        }
    }

    let (x_lo, x_hi) = (x_vals[lo], x_vals[hi]);
    let (y_lo, y_hi) = (y_vals[lo], y_vals[hi]);
    let span = x_hi - x_lo;
    if span.abs() < f64::EPSILON {
        return y_lo;
    }

    y_lo + (y_hi - y_lo) * (x - x_lo) / span
}
178
179/// Add trajectory flags using vectorized detection algorithms
180fn add_trajectory_flags(
181    samples: &mut [TrajectorySample],
182    transonic_distances: &[f64],
183    target_distance_input_m: f64,
184) {
185    let tolerance = 1e-6;
186    
187    // 1. Zero crossings - vectorized detection
188    detect_zero_crossings(samples, tolerance);
189    
190    // 2. Mach transitions
191    for &transonic_dist in transonic_distances {
192        if let Some(idx) = find_closest_sample_index(samples, transonic_dist) {
193            samples[idx].flags.push(TrajectoryFlag::MachTransition);
194        }
195    }
196    
197    // 3. Apex - find the point with maximum height between muzzle and target
198    // Since drop is positive when bullet is below LOS and negative when above,
199    // the apex is where drop is minimum (most negative)
200    if samples.len() > 2 {
201        // Use the target distance passed as parameter
202        let target_distance_m = target_distance_input_m;
203        
204        // Find the index of maximum height (minimum drop, most negative) within target distance
205        // Exclude first point (always 0 for auto-zeroing)
206        let mut min_drop = f64::INFINITY;
207        let mut apex_idx = 1;
208
209        // Search from index 1, but stop at target distance
210        for i in 1..samples.len() {
211            // Only consider points up to target distance
212            if samples[i].distance_m > target_distance_m {
213                break;
214            }
215
216            if samples[i].drop_m < min_drop {
217                min_drop = samples[i].drop_m;
218                apex_idx = i;
219            }
220        }
221        
222        // Mark the apex
223        samples[apex_idx].flags.push(TrajectoryFlag::Apex);
224    }
225}
226
227/// Detect zero crossings in trajectory drop values using vectorized operations
228fn detect_zero_crossings(samples: &mut [TrajectorySample], tolerance: f64) {
229    if samples.len() < 2 {
230        return;
231    }
232    
233    let drops: Vec<f64> = samples.iter().map(|s| s.drop_m).collect();
234    
235    // Find crossing indices where drop changes sign
236    for i in 0..(drops.len() - 1) {
237        let current = drops[i];
238        let next = drops[i + 1];
239        
240        // Check for sign change crossings
241        let crosses_zero = (current < -tolerance && next >= -tolerance) ||
242                          (current > tolerance && next <= tolerance);
243        
244        if crosses_zero {
245            samples[i + 1].flags.push(TrajectoryFlag::ZeroCrossing);
246        }
247    }
248    
249    // Find points very close to zero
250    for (i, &drop) in drops.iter().enumerate() {
251        if drop.abs() <= tolerance {
252            samples[i].flags.push(TrajectoryFlag::ZeroCrossing);
253        }
254    }
255    
256    // Remove duplicate zero crossing flags
257    for sample in samples.iter_mut() {
258        let mut unique_flags = Vec::new();
259        let mut seen = HashSet::new();
260        
261        for flag in &sample.flags {
262            if seen.insert(flag.clone()) {
263                unique_flags.push(flag.clone());
264            }
265        }
266        sample.flags = unique_flags;
267    }
268}
269
270/// Find the closest sample index to a given distance
271fn find_closest_sample_index(samples: &[TrajectorySample], target_distance: f64) -> Option<usize> {
272    if samples.is_empty() {
273        return None;
274    }
275    
276    // Binary search for the closest distance
277    let distances: Vec<f64> = samples.iter().map(|s| s.distance_m).collect();
278    
279    let mut left = 0;
280    let mut right = distances.len();
281    
282    while left < right {
283        let mid = (left + right) / 2;
284        if distances[mid] < target_distance {
285            left = mid + 1;
286        } else {
287            right = mid;
288        }
289    }
290    
291    // Find the closest point (could be left-1 or left)
292    let mut best_idx = left.min(distances.len() - 1);
293    
294    if left > 0 {
295        let left_dist = (distances[left - 1] - target_distance).abs();
296        let right_dist = (distances[best_idx] - target_distance).abs();
297
298        // Prefer earlier index in case of tie
299        if left_dist <= right_dist {
300            best_idx = left - 1;
301        }
302    }
303    
304    Some(best_idx)
305}
306
307/// Convert trajectory samples to Python-compatible format
308pub fn trajectory_samples_to_dicts(samples: &[TrajectorySample]) -> Vec<TrajectoryDict> {
309    samples.iter().map(|sample| {
310        TrajectoryDict {
311            distance_m: sample.distance_m,
312            drop_m: sample.drop_m,
313            wind_drift_m: sample.wind_drift_m,
314            velocity_mps: sample.velocity_mps,
315            energy_j: sample.energy_j,
316            time_s: sample.time_s,
317            flags: sample.flags.iter().map(|f| f.to_string()).collect(),
318        }
319    }).collect()
320}
321
/// Python-compatible copy of a trajectory sample with flags rendered as strings.
///
/// The numeric fields mirror `TrajectorySample`. `flags` holds snake_case flag names
/// such as `"zero_crossing"` instead of enum values. Derives `PartialEq` so converted
/// records can be compared by value.
#[derive(Debug, Clone, PartialEq)]
pub struct TrajectoryDict {
    /// Downrange distance, in metres.
    pub distance_m: f64,
    /// Distance below the line of sight, in metres (positive = below).
    pub drop_m: f64,
    /// Lateral drift, in metres.
    pub wind_drift_m: f64,
    /// Speed, in metres per second.
    pub velocity_mps: f64,
    /// Kinetic energy, in joules.
    pub energy_j: f64,
    /// Time of flight, in seconds.
    pub time_s: f64,
    /// Flag names in the order they were attached.
    pub flags: Vec<String>,
}
333
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a flag-free sample from explicit field values.
    fn sample(
        distance_m: f64,
        drop_m: f64,
        wind_drift_m: f64,
        velocity_mps: f64,
        energy_j: f64,
        time_s: f64,
    ) -> TrajectorySample {
        TrajectorySample {
            distance_m,
            drop_m,
            wind_drift_m,
            velocity_mps,
            energy_j,
            time_s,
            flags: Vec::new(),
        }
    }

    #[test]
    fn test_interpolate() {
        let xs = vec![0.0, 1.0, 2.0, 3.0];
        let ys = vec![0.0, 10.0, 20.0, 30.0];

        // Interior points are interpolated linearly.
        for &(x, expected) in &[(0.5, 5.0), (1.5, 15.0), (2.5, 25.0)] {
            assert_eq!(interpolate(&xs, &ys, x), expected);
        }

        // Outside the tabulated range the end values are held.
        assert_eq!(interpolate(&xs, &ys, -1.0), 0.0);
        assert_eq!(interpolate(&xs, &ys, 4.0), 30.0);
    }

    #[test]
    fn test_find_closest_sample_index() {
        let samples = vec![
            sample(0.0, 0.0, 0.0, 100.0, 1000.0, 0.0),
            sample(10.0, -1.0, 0.1, 95.0, 950.0, 0.1),
            sample(20.0, -4.0, 0.2, 90.0, 900.0, 0.2),
        ];

        for &(target, expected) in &[(5.0, 0), (12.0, 1), (18.0, 2)] {
            assert_eq!(find_closest_sample_index(&samples, target), Some(expected));
        }
    }

    #[test]
    fn test_detect_zero_crossings() {
        let mut samples = vec![
            sample(0.0, 1.0, 0.0, 100.0, 1000.0, 0.0), // positive drop
            sample(10.0, -0.5, 0.1, 95.0, 950.0, 0.1), // negative: the crossing lands here
            sample(20.0, -2.0, 0.2, 90.0, 900.0, 0.2), // still negative
        ];

        detect_zero_crossings(&mut samples, 1e-6);

        // Only the sample right after the sign change carries the flag.
        let flagged: Vec<bool> = samples
            .iter()
            .map(|s| s.flags.contains(&TrajectoryFlag::ZeroCrossing))
            .collect();
        assert_eq!(flagged, vec![false, true, false]);
    }

    #[test]
    fn test_sample_trajectory_basic() {
        // Axes: x = downrange, y = vertical, z = lateral.
        let trajectory_data = TrajectoryData {
            times: vec![0.0, 1.0, 2.0],
            positions: vec![
                Vector3::new(0.0, 0.0, 0.0),    // muzzle
                Vector3::new(100.0, 10.0, 1.0), // highest point, 1 m drift
                Vector3::new(200.0, 5.0, 2.0),  // end, 2 m drift
            ],
            velocities: vec![
                Vector3::new(1.0, 10.0, 100.0),
                Vector3::new(1.0, 5.0, 95.0),
                Vector3::new(1.0, 0.0, 90.0),
            ],
            transonic_distances: vec![150.0],
        };
        let outputs = TrajectoryOutputs {
            target_distance_horiz_m: 200.0,
            target_vertical_height_m: 0.0,
            time_of_flight_s: 2.0,
            max_ord_dist_horiz_m: 100.0,
        };

        let samples = sample_trajectory(&trajectory_data, &outputs, 50.0, 0.1);

        // One sample every 50 m from the muzzle to the target.
        let expected_distances = [0.0, 50.0, 100.0, 150.0, 200.0];
        assert_eq!(samples.len(), expected_distances.len());
        for (s, &d) in samples.iter().zip(expected_distances.iter()) {
            assert_eq!(s.distance_m, d);
        }

        // The speed at 50 m lies between the tabulated neighbours.
        let speed = samples[1].velocity_mps;
        assert!(speed > 90.0 && speed < 100.0);

        assert!(samples[2].flags.contains(&TrajectoryFlag::Apex));
        assert!(samples[3].flags.contains(&TrajectoryFlag::MachTransition));
    }
}
470        assert!(samples[3].flags.contains(&TrajectoryFlag::MachTransition)); // At transonic distance
471    }
472}