use nalgebra::Vector3;
use std::collections::HashSet;
/// Event marker attached to a [`TrajectorySample`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum TrajectoryFlag {
    /// The drop relative to the line of sight changes sign between neighbouring
    /// samples, or is within tolerance of zero at this sample.
    ZeroCrossing,
    /// This sample is the one nearest a recorded transonic distance.
    MachTransition,
    /// This sample has the smallest `drop_m` up to the target distance
    /// (the first sample is not considered).
    Apex,
}

impl TrajectoryFlag {
    /// Returns the flag's stable snake_case name (e.g. `"apex"`).
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::ZeroCrossing => "zero_crossing",
            Self::MachTransition => "mach_transition",
            Self::Apex => "apex",
        }
    }
}

/// Formats the flag as its snake_case name; `flag.to_string()` is provided
/// through this impl via the standard `ToString` blanket implementation.
impl std::fmt::Display for TrajectoryFlag {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `pad` (rather than `write_str`) honours width/alignment specifiers.
        f.pad(self.as_str())
    }
}
/// One grid point of a resampled trajectory, as built by `sample_trajectory`.
///
/// Units are presumably those named by the field suffixes (`_m` metres,
/// `_mps` metres per second, `_j` joules, `_s` seconds) — confirm against the solver.
#[derive(Debug, Clone)]
pub struct TrajectorySample {
    /// Downrange distance of this grid point.
    pub distance_m: f64,
    /// Line-of-sight height minus trajectory height at this distance;
    /// positive when the trajectory is below the line of sight.
    pub drop_m: f64,
    /// Lateral offset, interpolated from the `x` component of the raw positions.
    pub wind_drift_m: f64,
    /// Speed, interpolated from the velocity-vector magnitudes at the raw points.
    pub velocity_mps: f64,
    /// Kinetic energy (`0.5 * mass * speed²`), interpolated from the raw points.
    pub energy_j: f64,
    /// Time of flight at this distance, interpolated from the raw times.
    pub time_s: f64,
    /// Events located at this sample (zero crossing, Mach transition, apex).
    pub flags: Vec<TrajectoryFlag>,
}
/// Raw output of a trajectory simulation, consumed by `sample_trajectory`.
///
/// `times`, `positions` and `velocities` are expected to be parallel
/// (equal-length) arrays: they are interpolated against `positions[i].z`, and
/// `interpolate` yields `0.0` for mismatched lengths. Because `z` is the
/// interpolation key it is treated as downrange distance and should be
/// non-decreasing.
#[derive(Debug, Clone)]
pub struct TrajectoryData {
    /// Time of each raw point (presumably seconds).
    pub times: Vec<f64>,
    /// Position of each raw point: `x` lateral drift, `y` height, `z` downrange.
    pub positions: Vec<Vector3<f64>>,
    /// Velocity of each raw point; `sample_trajectory` uses only its magnitude.
    pub velocities: Vec<Vector3<f64>>,
    /// Downrange distances of Mach transitions; each one flags the nearest sample.
    pub transonic_distances: Vec<f64>,
}
/// Summary values of a trajectory solution used to set up sampling.
#[derive(Debug, Clone)]
pub struct TrajectoryOutputs {
    /// Horizontal distance to the target; upper end of the sampling grid.
    pub target_distance_horiz_m: f64,
    /// Target height; end point of the line of sight used for `drop_m`.
    pub target_vertical_height_m: f64,
    /// Total time of flight (not read by `sample_trajectory`).
    pub time_of_flight_s: f64,
    /// Presumably the horizontal distance of the maximum ordinate (not read by
    /// `sample_trajectory`).
    pub max_ord_dist_horiz_m: f64,
    /// Sight height at distance 0; start point of the line of sight.
    pub sight_height_m: f64,
}
/// Resamples a simulated trajectory onto an evenly spaced downrange grid.
///
/// Raw points come from `trajectory_data`:
/// * `position.z` is the downrange distance and the interpolation key, so it is
///   expected to be non-decreasing;
/// * `position.y` is the height;
/// * `position.x` is the lateral wind drift.
///
/// Grid points are `i * step` starting at 0 and are kept while they do not
/// exceed the target distance by more than 0.1. Each grid point is linearly
/// interpolated to produce:
///
/// * `drop_m` – height of the straight line from `sight_height_m` (distance 0)
///   to `target_vertical_height_m` (target distance), minus the trajectory
///   height. It is positive when the trajectory is below that line.
/// * `wind_drift_m` and `time_s`.
/// * `velocity_mps`, the speed, i.e. the velocity magnitude at each raw point.
/// * `energy_j`, computed as `0.5 * mass_kg * speed²` at each raw point and
///   then interpolated.
///
/// The resulting samples are then tagged with zero-crossing, Mach-transition
/// and apex flags.
///
/// Steps smaller than 0.1 are raised to 0.1. An empty vector is returned when
/// `step_m` is NaN or not positive, or when the target distance is not finite
/// or is below `1e-9`.
pub fn sample_trajectory(
    trajectory_data: &TrajectoryData,
    outputs: &TrajectoryOutputs,
    step_m: f64,
    mass_kg: f64,
) -> Vec<TrajectorySample> {
    if step_m.is_nan() || step_m <= 0.0 {
        return Vec::new();
    }
    let step_size = step_m.max(0.1);

    let max_dist = outputs.target_distance_horiz_m;
    // A non-finite target would saturate the step count below at `usize::MAX`
    // and overflow on `+ 1` (a panic in debug builds, an empty result in release).
    if !max_dist.is_finite() || max_dist < 1e-9 {
        return Vec::new();
    }

    let downrange: Vec<f64> = trajectory_data.positions.iter().map(|p| p.z).collect();
    let heights: Vec<f64> = trajectory_data.positions.iter().map(|p| p.y).collect();
    let drifts: Vec<f64> = trajectory_data.positions.iter().map(|p| p.x).collect();
    let speeds: Vec<f64> = trajectory_data.velocities.iter().map(|v| v.norm()).collect();
    let energies: Vec<f64> = speeds
        .iter()
        .map(|&speed| 0.5 * mass_kg * speed * speed)
        .collect();

    let num_steps = ((max_dist / step_size).ceil() as usize).saturating_add(1);
    // Grid points up to 0.1 past the target are kept, presumably to absorb
    // floating-point error in `i * step_size`. The grid only increases, so the
    // scan can stop at the first point beyond the limit.
    let limit = max_dist + 0.1;

    let mut samples: Vec<TrajectorySample> = (0..num_steps)
        .map(|i| i as f64 * step_size)
        .take_while(|&distance| distance <= limit)
        .map(|distance| {
            let height = interpolate(&downrange, &heights, distance);
            let los_height = outputs.sight_height_m
                + (outputs.target_vertical_height_m - outputs.sight_height_m) * distance / max_dist;
            TrajectorySample {
                distance_m: distance,
                drop_m: los_height - height,
                wind_drift_m: interpolate(&downrange, &drifts, distance),
                velocity_mps: interpolate(&downrange, &speeds, distance),
                energy_j: interpolate(&downrange, &energies, distance),
                time_s: interpolate(&downrange, &trajectory_data.times, distance),
                flags: Vec::new(),
            }
        })
        .collect();

    add_trajectory_flags(&mut samples, &trajectory_data.transonic_distances, max_dist);
    samples
}
/// Piecewise-linear interpolation of `y_vals` over the knots `x_vals` at `x`.
///
/// `x_vals` is expected to be sorted ascending, because the bracketing pair is
/// found by bisection. Queries at or before the first knot return the first `y`.
/// Queries at or after the last knot return the last `y`; there is no
/// extrapolation.
///
/// Returns `0.0` when either slice is empty or when their lengths differ.
/// If the two bracketing knots are closer than `f64::EPSILON`, the left `y`
/// is returned.
fn interpolate(x_vals: &[f64], y_vals: &[f64], x: f64) -> f64 {
    let n = x_vals.len();
    if n == 0 || n != y_vals.len() {
        return 0.0;
    }

    let last = n - 1;
    if x <= x_vals[0] {
        return y_vals[0];
    }
    if x >= x_vals[last] {
        return y_vals[last];
    }

    // Narrow [lo, hi] until the two knots are adjacent; for sorted input this
    // leaves x_vals[lo] <= x < x_vals[hi].
    let (mut lo, mut hi) = (0usize, last);
    while hi - lo > 1 {
        let mid = lo + (hi - lo) / 2;
        if x_vals[mid] <= x {
            lo = mid;
        } else {
            hi = mid;
        }
    }

    let (x0, x1) = (x_vals[lo], x_vals[hi]);
    let (y0, y1) = (y_vals[lo], y_vals[hi]);
    let span = x1 - x0;
    if span.abs() < f64::EPSILON {
        y0
    } else {
        y0 + (y1 - y0) * (x - x0) / span
    }
}
/// Annotates `samples` in place with [`TrajectoryFlag`]s.
///
/// * `ZeroCrossing` is delegated to `detect_zero_crossings` with a `1e-6`
///   tolerance. That call also de-duplicates any flags already present.
/// * `MachTransition` goes on the sample nearest each finite entry of
///   `transonic_distances`. A sample receives it at most once.
/// * `Apex` is added only when there are more than two samples. It goes on
///   the sample with the smallest `drop_m` among indices `1..`, stopping at the
///   first sample whose distance exceeds `target_distance_input_m`. If no
///   sample qualifies, index 1 is used.
///
/// `samples` is assumed to be ordered by ascending `distance_m`.
fn add_trajectory_flags(
    samples: &mut [TrajectorySample],
    transonic_distances: &[f64],
    target_distance_input_m: f64,
) {
    const ZERO_TOLERANCE: f64 = 1e-6;
    detect_zero_crossings(samples, ZERO_TOLERANCE);

    for &transonic_dist in transonic_distances {
        // NaN would bisect to index 0 (tagging the first sample) and infinity to
        // the end of the grid; neither is a real transition location.
        if !transonic_dist.is_finite() {
            continue;
        }
        // NOTE(review): a transition beyond the sampled range still tags the
        // nearest end sample — confirm whether such transitions should be dropped.
        if let Some(idx) = find_closest_sample_index(samples, transonic_dist) {
            let flags = &mut samples[idx].flags;
            // Several transitions can map to the same sample; keep flags unique.
            if !flags.contains(&TrajectoryFlag::MachTransition) {
                flags.push(TrajectoryFlag::MachTransition);
            }
        }
    }

    if samples.len() > 2 {
        let mut apex_idx = 1;
        let mut min_drop = f64::INFINITY;
        for (i, sample) in samples.iter().enumerate().skip(1) {
            if sample.distance_m > target_distance_input_m {
                break;
            }
            if sample.drop_m < min_drop {
                min_drop = sample.drop_m;
                apex_idx = i;
            }
        }
        let flags = &mut samples[apex_idx].flags;
        if !flags.contains(&TrajectoryFlag::Apex) {
            flags.push(TrajectoryFlag::Apex);
        }
    }
}
/// Tags samples whose `drop_m` crosses or touches zero, then removes duplicate
/// flags on every sample.
///
/// A `ZeroCrossing` flag is pushed in two cases:
/// * on the later sample of each neighbouring pair where the drop moves from
///   below `-tolerance` to at least `-tolerance`, or from above `tolerance` to
///   at most `tolerance`;
/// * on each sample whose `|drop_m| <= tolerance`.
///
/// After that, each sample's flag list keeps only the first occurrence of
/// every flag, in the original order. With fewer than two samples the
/// function does nothing at all.
fn detect_zero_crossings(samples: &mut [TrajectorySample], tolerance: f64) {
    if samples.len() < 2 {
        return;
    }

    // Snapshot the drops so flags can be pushed while scanning neighbouring pairs.
    let drops: Vec<f64> = samples.iter().map(|s| s.drop_m).collect();

    for (offset, pair) in drops.windows(2).enumerate() {
        let (prev, cur) = (pair[0], pair[1]);
        let rises_through_zero = prev < -tolerance && cur >= -tolerance;
        let falls_through_zero = prev > tolerance && cur <= tolerance;
        if rises_through_zero || falls_through_zero {
            samples[offset + 1].flags.push(TrajectoryFlag::ZeroCrossing);
        }
    }

    for (sample, d) in samples.iter_mut().zip(&drops) {
        if d.abs() <= tolerance {
            sample.flags.push(TrajectoryFlag::ZeroCrossing);
        }
    }

    for sample in samples.iter_mut() {
        let mut seen = HashSet::new();
        sample.flags.retain(|flag| seen.insert(flag.clone()));
    }
}
/// Returns the index of the sample whose `distance_m` is closest to
/// `target_distance`, or `None` when `samples` is empty.
///
/// `samples` must be sorted by ascending `distance_m`, because the search is a
/// bisection. When the target is equidistant from two neighbours, the earlier
/// sample wins. Targets outside the sampled range map to the nearest end.
fn find_closest_sample_index(samples: &[TrajectorySample], target_distance: f64) -> Option<usize> {
    let last = samples.len().checked_sub(1)?;
    // First sample at or beyond the target (may be one past the end). Searching
    // the samples directly avoids copying every distance into a scratch vector.
    let upper = samples.partition_point(|s| s.distance_m < target_distance);
    let mut best = upper.min(last);
    if upper > 0 {
        let below = upper - 1;
        let below_gap = (samples[below].distance_m - target_distance).abs();
        let above_gap = (samples[best].distance_m - target_distance).abs();
        // `<=` breaks ties toward the earlier sample.
        if below_gap <= above_gap {
            best = below;
        }
    }
    Some(best)
}
/// Converts samples into [`TrajectoryDict`] records.
///
/// Numeric fields are copied unchanged. Each flag is rendered through
/// `to_string`, giving `"zero_crossing"`, `"mach_transition"` or `"apex"`.
/// The output has one record per input sample, in the same order.
pub fn trajectory_samples_to_dicts(samples: &[TrajectorySample]) -> Vec<TrajectoryDict> {
    let mut records = Vec::with_capacity(samples.len());
    for sample in samples {
        let flag_names: Vec<String> = sample.flags.iter().map(|flag| flag.to_string()).collect();
        records.push(TrajectoryDict {
            distance_m: sample.distance_m,
            drop_m: sample.drop_m,
            wind_drift_m: sample.wind_drift_m,
            velocity_mps: sample.velocity_mps,
            energy_j: sample.energy_j,
            time_s: sample.time_s,
            flags: flag_names,
        });
    }
    records
}
/// Copy of a `TrajectorySample` with its flags rendered as strings, built by
/// `trajectory_samples_to_dicts` (presumably for serialization or an FFI
/// boundary — confirm with callers).
#[derive(Debug, Clone)]
pub struct TrajectoryDict {
    /// Downrange distance of this grid point.
    pub distance_m: f64,
    /// Line-of-sight height minus trajectory height; positive below the line.
    pub drop_m: f64,
    /// Lateral offset of the trajectory.
    pub wind_drift_m: f64,
    /// Interpolated speed.
    pub velocity_mps: f64,
    /// Interpolated kinetic energy.
    pub energy_j: f64,
    /// Interpolated time of flight.
    pub time_s: f64,
    /// Flag names such as `"zero_crossing"`, `"mach_transition"`, `"apex"`.
    pub flags: Vec<String>,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a sample where only distance and drop matter to the code under test.
    fn sample(distance_m: f64, drop_m: f64) -> TrajectorySample {
        TrajectorySample {
            distance_m,
            drop_m,
            wind_drift_m: 0.0,
            velocity_mps: 0.0,
            energy_j: 0.0,
            time_s: 0.0,
            flags: Vec::new(),
        }
    }

    /// Three raw points: height 0 at 0 m, 10 at 100 m, 5 at 200 m downrange;
    /// one Mach transition at 150 m.
    fn basic_trajectory() -> TrajectoryData {
        TrajectoryData {
            times: vec![0.0, 1.0, 2.0],
            positions: vec![
                Vector3::new(0.0, 0.0, 0.0),
                Vector3::new(1.0, 10.0, 100.0),
                Vector3::new(2.0, 5.0, 200.0),
            ],
            velocities: vec![
                Vector3::new(1.0, 10.0, 100.0),
                Vector3::new(1.0, 5.0, 95.0),
                Vector3::new(1.0, 0.0, 90.0),
            ],
            transonic_distances: vec![150.0],
        }
    }

    /// Outputs with a horizontal line of sight at height 0.
    fn flat_outputs(target_distance_horiz_m: f64) -> TrajectoryOutputs {
        TrajectoryOutputs {
            target_distance_horiz_m,
            target_vertical_height_m: 0.0,
            time_of_flight_s: 2.0,
            max_ord_dist_horiz_m: 100.0,
            sight_height_m: 0.0,
        }
    }

    #[test]
    fn test_interpolate() {
        let x_vals = [0.0, 1.0, 2.0, 3.0];
        let y_vals = [0.0, 10.0, 20.0, 30.0];
        assert_eq!(interpolate(&x_vals, &y_vals, 0.5), 5.0);
        assert_eq!(interpolate(&x_vals, &y_vals, 1.5), 15.0);
        assert_eq!(interpolate(&x_vals, &y_vals, 2.5), 25.0);
        // Queries outside the knots clamp to the end values.
        assert_eq!(interpolate(&x_vals, &y_vals, -1.0), 0.0);
        assert_eq!(interpolate(&x_vals, &y_vals, 4.0), 30.0);
    }

    #[test]
    fn test_interpolate_degenerate_inputs() {
        assert_eq!(interpolate(&[], &[], 1.0), 0.0);
        assert_eq!(interpolate(&[0.0, 1.0], &[5.0], 0.5), 0.0);
        assert_eq!(interpolate(&[2.0], &[7.0], 100.0), 7.0);
    }

    #[test]
    fn test_find_closest_sample_index() {
        let samples = [sample(0.0, 0.0), sample(10.0, -1.0), sample(20.0, -4.0)];
        // Exactly halfway between two samples: the earlier one wins.
        assert_eq!(find_closest_sample_index(&samples, 5.0), Some(0));
        assert_eq!(find_closest_sample_index(&samples, 12.0), Some(1));
        assert_eq!(find_closest_sample_index(&samples, 18.0), Some(2));
        // Targets outside the sampled range map to the nearest end.
        assert_eq!(find_closest_sample_index(&samples, -5.0), Some(0));
        assert_eq!(find_closest_sample_index(&samples, 25.0), Some(2));
        assert_eq!(find_closest_sample_index(&[], 3.0), None);
    }

    #[test]
    fn test_detect_zero_crossings() {
        let mut samples = [sample(0.0, 1.0), sample(10.0, -0.5), sample(20.0, -2.0)];
        detect_zero_crossings(&mut samples, 1e-6);
        assert!(!samples[0].flags.contains(&TrajectoryFlag::ZeroCrossing));
        assert!(samples[1].flags.contains(&TrajectoryFlag::ZeroCrossing));
        assert!(!samples[2].flags.contains(&TrajectoryFlag::ZeroCrossing));
    }

    #[test]
    fn test_detect_zero_crossings_flags_each_sample_once() {
        // Sample 1 is both reached by a sign change and exactly on the line of sight.
        let mut samples = [sample(0.0, -1.0), sample(10.0, 0.0), sample(20.0, 1.0)];
        detect_zero_crossings(&mut samples, 1e-6);
        assert!(samples[0].flags.is_empty());
        assert_eq!(samples[1].flags, vec![TrajectoryFlag::ZeroCrossing]);
        assert!(samples[2].flags.is_empty());
    }

    #[test]
    fn test_sample_trajectory_basic() {
        let samples = sample_trajectory(&basic_trajectory(), &flat_outputs(200.0), 50.0, 0.1);
        let distances: Vec<f64> = samples.iter().map(|s| s.distance_m).collect();
        assert_eq!(distances, vec![0.0, 50.0, 100.0, 150.0, 200.0]);
        assert!(samples[1].velocity_mps > 90.0 && samples[1].velocity_mps < 100.0);
        assert!(samples[2].flags.contains(&TrajectoryFlag::Apex));
        assert!(samples[3].flags.contains(&TrajectoryFlag::MachTransition));
    }

    #[test]
    fn test_sample_trajectory_rejects_degenerate_inputs() {
        let data = basic_trajectory();
        assert!(sample_trajectory(&data, &flat_outputs(200.0), 0.0, 0.1).is_empty());
        assert!(sample_trajectory(&data, &flat_outputs(200.0), -1.0, 0.1).is_empty());
        assert!(sample_trajectory(&data, &flat_outputs(0.0), 50.0, 0.1).is_empty());
    }

    #[test]
    fn test_sample_trajectory_clamps_small_steps() {
        // A 0.01 step is raised to 0.1: grid 0.0, 0.1, 0.2, ~0.3 (within 0.1 of 0.25).
        let samples = sample_trajectory(&basic_trajectory(), &flat_outputs(0.25), 0.01, 0.1);
        assert_eq!(samples.len(), 4);
        assert_eq!(samples[1].distance_m, 0.1);
    }
}