use nalgebra::Vector3;
use std::f64::consts::PI;
use crate::core::{PoliastroError, PoliastroResult};
use crate::core::fast_math;
/// Which of the two geometric arcs between the two position vectors a Lambert transfer uses.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TransferKind {
    /// Let the solver decide: short way when `(r1 × r2).z >= 0`, long way otherwise
    /// (i.e. prefer motion that is counter-clockwise about +z).
    Auto,
    /// Transfer angle below 180°: motion in the sense of the `r1 × r2` normal.
    ShortWay,
    /// Transfer angle above 180°: motion opposite to the `r1 × r2` normal.
    LongWay,
}
/// Result of a Lambert solve: the boundary conditions that were requested, the departure and
/// arrival velocities that satisfy them, and a few elements of the transfer orbit.
///
/// Units follow the inputs given to the solver. NOTE(review): the solver's validation message
/// refers to metres, so SI units (m, s, m³/s²) are presumably intended — confirm.
#[derive(Debug, Clone, PartialEq)]
pub struct LambertSolution {
/// Departure position vector, exactly as passed to the solver.
pub r1: Vector3<f64>,
/// Arrival position vector, exactly as passed to the solver.
pub r2: Vector3<f64>,
/// Time of flight between `r1` and `r2`, as passed to the solver.
pub tof: f64,
/// Velocity at `r1` on the transfer orbit.
pub v1: Vector3<f64>,
/// Velocity at `r2` on the transfer orbit.
pub v2: Vector3<f64>,
/// Gravitational parameter used for the solve.
pub mu: f64,
/// Semi-major axis of the transfer orbit (from vis-viva in the zero-revolution solver, so it
/// is negative for hyperbolic transfers).
pub a: f64,
/// Eccentricity: magnitude of the eccentricity vector computed from `r1` and `v1`.
pub e: f64,
/// Number of complete revolutions that was requested.
pub revs: u32,
/// `true` when the short-way arc was used (the resolved value of the requested `TransferKind`).
pub short_way: bool,
}
pub struct Lambert;
impl Lambert {
pub fn solve(
r1: Vector3<f64>,
r2: Vector3<f64>,
tof: f64,
mu: f64,
transfer_kind: TransferKind,
revs: u32,
) -> PoliastroResult<LambertSolution> {
let r1_mag = r1.norm();
let r2_mag = r2.norm();
if r1_mag < 1.0 || r2_mag < 1.0 {
return Err(PoliastroError::invalid_parameter(
"position magnitude",
r1_mag.min(r2_mag),
"must be > 1 m",
));
}
if tof <= 0.0 {
return Err(PoliastroError::invalid_parameter(
"time of flight",
tof,
"must be positive",
));
}
if mu <= 0.0 {
return Err(PoliastroError::invalid_parameter(
"gravitational parameter",
mu,
"must be positive",
));
}
if revs > 0 {
return Self::solve_multi_revolution(r1, r2, tof, mu, transfer_kind, revs);
}
let cos_dnu = r1.dot(&r2) / (r1_mag * r2_mag);
let cross = r1.cross(&r2);
let cross_mag = cross.norm();
let short_way = match transfer_kind {
TransferKind::ShortWay => true,
TransferKind::LongWay => false,
TransferKind::Auto => {
cross[2] >= 0.0
}
};
let sin_dnu = if short_way {
cross_mag / (r1_mag * r2_mag)
} else {
-cross_mag / (r1_mag * r2_mag)
};
let _dnu = f64::atan2(sin_dnu, cos_dnu);
let a_param = (r1_mag * r2_mag * (1.0 + cos_dnu)).sqrt();
if a_param.abs() < 1e-6 {
return Err(PoliastroError::invalid_state(
"Position vectors are nearly opposite (transfer trajectory not unique)",
));
}
let mut z = 0.0;
const MAX_ITER: usize = 100;
const TOL: f64 = 1e-8;
let mut converged = false;
for iter in 0..MAX_ITER {
let (c2, c3) = fast_math::stumpff_cs(z);
let y = r1_mag + r2_mag + a_param * (z * c3 - 1.0) / c2.sqrt();
if y <= 0.0 {
z += 0.1;
continue;
}
let chi = y.sqrt() / c2.sqrt();
let tof_calc = (chi.powi(3) * c3 + a_param * y.sqrt()) / mu.sqrt();
let error = tof - tof_calc;
if error.abs() < TOL {
converged = true;
break;
}
let dt_dz = if z.abs() < 1e-6 {
(chi.powi(3) / 40.0 + a_param / 8.0) / mu.sqrt()
} else {
let (_c2_prime, c3_prime) = fast_math::stumpff_derivatives(z, c2, c3);
let dy_dz = a_param * (c3_prime - 1.5 * c2 * c3 / c2) / c2.sqrt();
let dchi_dz = (1.0 / (2.0 * chi) - chi / (2.0 * y) * dy_dz) / c2.sqrt();
(3.0 * chi.powi(2) * c3 * dchi_dz
+ chi.powi(3) * c3_prime
+ a_param / (2.0 * y.sqrt()) * dy_dz)
/ mu.sqrt()
};
let z_new = z + error / dt_dz;
if iter > 10 && (z_new - z).abs() < 1e-12 {
converged = true;
break;
}
z = z_new;
}
if !converged {
return Err(PoliastroError::convergence_failure(
"Lambert universal variable solver",
MAX_ITER,
TOL,
));
}
let (c2, c3) = fast_math::stumpff_cs(z); let y = r1_mag + r2_mag + a_param * (z * c3 - 1.0) / c2.sqrt();
let _chi = y.sqrt() / c2.sqrt();
let f = 1.0 - y / r1_mag;
let g = a_param * y.sqrt() / mu.sqrt();
let g_dot = 1.0 - y / r2_mag;
let v1 = (r2 - f * r1) / g;
let v2 = (g_dot * r2 - r1) / g;
let h = r1.cross(&v1); let _h_mag = h.norm();
let a = 1.0 / (2.0 / r1_mag - v1.dot(&v1) / mu); let e_vec = (v1.cross(&h) / mu) - r1 / r1_mag; let e = e_vec.norm();
Ok(LambertSolution {
r1,
r2,
tof,
v1,
v2,
mu,
a,
e,
revs,
short_way,
})
}
fn solve_multi_revolution(
r1: Vector3<f64>,
r2: Vector3<f64>,
tof: f64,
mu: f64,
transfer_kind: TransferKind,
revs: u32,
) -> PoliastroResult<LambertSolution> {
let r1_mag = r1.norm();
let r2_mag = r2.norm();
let _cos_dnu = r1.dot(&r2) / (r1_mag * r2_mag);
let cross = r1.cross(&r2);
let _cross_mag = cross.norm();
let short_way = match transfer_kind {
TransferKind::ShortWay => true,
TransferKind::LongWay => false,
TransferKind::Auto => cross[2] >= 0.0,
};
let c = (r1 - r2).norm();
let s = (r1_mag + r2_mag + c) / 2.0;
let _a_min = s / 2.0;
let lambda = if short_way {
(1.0 - c / s).sqrt()
} else {
-(1.0 - c / s).sqrt()
};
let t_dimensionless = tof * mu.sqrt() / (2.0 * s.powf(1.5));
let t_00 = f64::acos(lambda) + lambda * (1.0 - lambda * lambda).sqrt();
let n_max = if t_dimensionless > t_00 {
((t_dimensionless - t_00) / PI).floor() as u32
} else {
0
};
if revs > n_max {
return Err(PoliastroError::invalid_parameter(
"revs",
revs as f64,
format!("exceeds maximum {n_max} revolutions for given TOF"),
));
}
let mut x = if revs == 1 {
0.0 } else {
let tmp = ((8.0 * t_dimensionless) / (revs as f64 * PI)).powf(2.0 / 3.0);
let x_guess = (tmp - 1.0) / (tmp + 1.0);
x_guess.clamp(-0.7, 0.7)
};
const MAX_ITER: usize = 50;
const TOL: f64 = 1e-8;
let mut converged = false;
for iter in 0..MAX_ITER {
let t_calc = time_of_flight_izzo(x, lambda, revs as i32);
let error = t_calc - t_dimensionless;
if error.abs() < TOL {
converged = true;
break;
}
let (dt_dx, d2t_dx2, d3t_dx3) = time_derivatives_izzo(x, lambda, revs as i32);
if dt_dx.abs() < 1e-15 {
break;
}
let delta = error / dt_dx;
let max_step = 0.3; let delta_limited = if delta.abs() > max_step {
max_step * delta.signum()
} else {
delta
};
x -= delta_limited;
x = x.clamp(-0.99, 0.99);
}
if !converged {
return Err(PoliastroError::convergence_failure(
"Izzo multi-revolution Lambert solver",
MAX_ITER,
TOL,
));
}
let y = (1.0 - lambda * lambda * (1.0 - x * x)).sqrt();
let gamma = (mu * s / 2.0).sqrt();
let rho = (r1_mag - r2_mag) / c;
let sigma = (2.0 * r1_mag * r2_mag / (c * c) - 1.0).sqrt();
let v_r1 = gamma * ((lambda * y - x) - rho * (lambda * y + x)) / r1_mag;
let v_r2 = -gamma * ((lambda * y - x) + rho * (lambda * y + x)) / r2_mag;
let v_t1 = gamma * sigma * (y + lambda * x) / r1_mag;
let v_t2 = gamma * sigma * (y + lambda * x) / r2_mag;
let i_r1 = r1 / r1_mag;
let i_t1 = cross.cross(&i_r1).normalize();
let i_r2 = r2 / r2_mag;
let i_t2 = cross.cross(&i_r2).normalize();
let v1 = v_r1 * i_r1 + v_t1 * i_t1;
let v2 = v_r2 * i_r2 + v_t2 * i_t2;
let a = s / (2.0 * (1.0 - x * x)); let h = r1.cross(&v1); let e_vec = (v1.cross(&h) / mu) - r1 / r1_mag; let e = e_vec.norm();
Ok(LambertSolution {
r1,
r2,
tof,
v1,
v2,
mu,
a,
e,
revs,
short_way,
})
}
pub fn solve_batch(
r1: Vector3<f64>,
r2: Vector3<f64>,
tofs: &[f64],
mu: f64,
transfer_kind: TransferKind,
revs: u32,
) -> PoliastroResult<Vec<LambertSolution>> {
tofs.iter()
.map(|&tof| Self::solve(r1, r2, tof, mu, transfer_kind, revs))
.collect()
}
pub fn solve_batch_parallel(
r1s: &[Vector3<f64>],
r2s: &[Vector3<f64>],
tofs: &[f64],
mu: f64,
transfer_kind: TransferKind,
revs: u32,
) -> PoliastroResult<Vec<LambertSolution>> {
use rayon::prelude::*;
if r1s.len() != r2s.len() || r1s.len() != tofs.len() {
return Err(PoliastroError::invalid_state(
"Arrays must have the same length",
));
}
let results: Result<Vec<_>, _> = r1s
.par_iter()
.zip(r2s.par_iter())
.zip(tofs.par_iter())
.map(|((r1, r2), tof)| Self::solve(*r1, *r2, *tof, mu, transfer_kind, revs))
.collect();
results
}
}
/// Stumpff functions `(c2(z), c3(z))` of the universal-variable formulation.
///
/// * `z > 0`: trigonometric closed form.
/// * `z < 0`: hyperbolic closed form.
/// * `|z| <= 1e-6`: 3-term series.
///
/// Both functions are continuous in `z`, and c3 is positive for every argument.
fn stumpff_functions(z: f64) -> (f64, f64) {
    const TOL: f64 = 1e-6;
    if z > TOL {
        let sqrt_z = z.sqrt();
        let c2 = (1.0 - sqrt_z.cos()) / z;
        let c3 = (sqrt_z - sqrt_z.sin()) / (z * sqrt_z);
        (c2, c3)
    } else if z < -TOL {
        let sqrt_neg_z = (-z).sqrt();
        let c2 = (1.0 - sqrt_neg_z.cosh()) / z;
        // c3 = (sinh(sqrt(-z)) - sqrt(-z)) / (-z)^(3/2), and (-z) * sqrt(-z) is that denominator.
        // Using z * sqrt(-z) would flip the sign and make c3 jump at z = -TOL.
        let c3 = (sqrt_neg_z.sinh() - sqrt_neg_z) / (-z * sqrt_neg_z);
        (c2, c3)
    } else {
        let c2 = 0.5 - z / 24.0 + z * z / 720.0;
        let c3 = 1.0 / 6.0 - z / 120.0 + z * z / 5040.0;
        (c2, c3)
    }
}

/// Derivatives `(dc2/dz, dc3/dz)` of the Stumpff functions, given `c2 = c2(z)` and `c3 = c3(z)`.
///
/// Uses `dc2/dz = (1 - z c3 - 2 c2) / (2z)` and `dc3/dz = (c2 - 3 c3) / (2z)`.
/// The series derivatives are used for |z| < 1e-6.
fn stumpff_derivatives(z: f64, c2: f64, c3: f64) -> (f64, f64) {
    const TOL: f64 = 1e-6;
    if z.abs() < TOL {
        let c2_prime = -1.0 / 24.0 + z / 360.0;
        let c3_prime = -1.0 / 120.0 + z / 2520.0;
        (c2_prime, c3_prime)
    } else {
        let c2_prime = (1.0 - z * c3 - 2.0 * c2) / (2.0 * z);
        let c3_prime = (c2 - 3.0 * c3) / (2.0 * z);
        (c2_prime, c3_prime)
    }
}

/// Time-of-flight model in the variable `x` used by the original multi-revolution code.
///
/// `x` is clamped to [-0.99, 0.99]. The value `1e10` is returned as a sentinel whenever the
/// intermediate quantities leave their valid domain.
///
/// NOTE(review): this is not Izzo's T(x) / Lagrange's equation. For example, with x = 0 and
/// lambda = 0 it returns π − 2 + 2nπ, whereas Izzo's T(0) = π/2 + nπ. Confirm the intended
/// formulation before relying on it.
fn time_of_flight_izzo(x: f64, lambda: f64, n: i32) -> f64 {
    let x_safe = x.clamp(-0.99, 0.99);
    let a = 1.0 / (1.0 - x_safe * x_safe);
    if a > 0.0 && a < 1e6 {
        let sqrt_a = a.sqrt();
        let alpha = 2.0 * f64::acos(x_safe);
        let y_sq = 1.0 - lambda * lambda * (1.0 - x_safe * x_safe);
        if y_sq < 0.0 {
            return 1e10;
        }
        let y_val = y_sq.sqrt();
        let beta_arg = lambda * y_val;
        let beta = if beta_arg.abs() <= 1.0 {
            2.0 * f64::asin(beta_arg)
        } else {
            return 1e10;
        };
        let psi = (alpha - beta) / 2.0;
        let psi_sin = psi.sin();
        let t_base = sqrt_a * a * 2.0 * (psi - psi_sin);
        if n == 0 {
            t_base
        } else {
            t_base + 2.0 * n as f64 * PI * sqrt_a * a
        }
    } else {
        1e10
    }
}

/// Finite-difference first, second and third derivatives of [`time_of_flight_izzo`] w.r.t. `x`.
///
/// Each order uses its own step. The second and third differences divide by h² and h³, so
/// they need much larger steps than the first. With h = 1e-8 they would be dominated by
/// round-off.
fn time_derivatives_izzo(x: f64, lambda: f64, n: i32) -> (f64, f64, f64) {
    let tof = |x: f64| time_of_flight_izzo(x, lambda, n);
    // First derivative: central difference.
    let h1 = 1e-8;
    let dt_dx = (tof(x + h1) - tof(x - h1)) / (2.0 * h1);
    // Second derivative: step near eps^(1/4).
    let h2 = 1e-4;
    let d2t_dx2 = (tof(x + h2) - 2.0 * tof(x) + tof(x - h2)) / (h2 * h2);
    // Third derivative: step near eps^(1/5).
    let h3 = 1e-3;
    let d3t_dx3 = (tof(x + 2.0 * h3) - 2.0 * tof(x + h3) + 2.0 * tof(x - h3)
        - tof(x - 2.0 * h3))
        / (2.0 * h3 * h3 * h3);
    (dt_dx, d2t_dx2, d3t_dx3)
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_stumpff_functions_parabolic() {
        let (c2, c3) = fast_math::stumpff_cs(0.0);
        assert_relative_eq!(c2, 0.5, epsilon = 1e-10);
        assert_relative_eq!(c3, 1.0 / 6.0, epsilon = 1e-10);
    }

    #[test]
    fn test_stumpff_functions_elliptic() {
        let z = 1.0;
        let (c2, c3) = fast_math::stumpff_cs(z);
        let sqrt_z = z.sqrt();
        let expected_c2 = (1.0 - sqrt_z.cos()) / z;
        let expected_c3 = (sqrt_z - sqrt_z.sin()) / (z * sqrt_z);
        assert_relative_eq!(c2, expected_c2, epsilon = 1e-10);
        assert_relative_eq!(c3, expected_c3, epsilon = 1e-10);
    }

    #[test]
    fn test_stumpff_functions_hyperbolic() {
        let z = -1.0;
        let (c2, c3) = fast_math::stumpff_cs(z);
        let sqrt_neg_z = (-z).sqrt();
        let expected_c2 = (1.0 - sqrt_neg_z.cosh()) / z;
        // NOTE(review): this expectation divides by z * sqrt(-z), so it pins c3(-1) = -(sinh 1 - 1).
        // The standard definition gives +(sinh 1 - 1). Confirm fast_math's convention.
        let expected_c3 = (sqrt_neg_z.sinh() - sqrt_neg_z) / (z * sqrt_neg_z);
        assert_relative_eq!(c2, expected_c2, epsilon = 1e-10);
        assert_relative_eq!(c3, expected_c3, epsilon = 1e-10);
    }

    #[test]
    fn test_lambert_simple_circular() {
        let mu = 3.986004418e14;
        let r: f64 = 7000e3;
        let r1 = Vector3::new(r, 0.0, 0.0);
        let r2 = Vector3::new(0.0, r, 0.0);
        let period = 2.0 * PI * (r.powi(3) / mu).sqrt();
        let tof = period / 4.0;
        let solution = Lambert::solve(r1, r2, tof, mu, TransferKind::Auto, 0).unwrap();
        let v_circular = (mu / r).sqrt();
        assert!(solution.v1.norm() > v_circular * 0.5);
        assert!(solution.v1.norm() < v_circular * 2.0);
        assert!(solution.v2.norm() > v_circular * 0.5);
        assert!(solution.v2.norm() < v_circular * 2.0);
        assert!(solution.short_way);
    }

    #[test]
    fn test_lambert_vallado_example() {
        let mu = 3.986004418e14;
        let r1 = Vector3::new(5000e3, 10000e3, 2100e3);
        let r2 = Vector3::new(-14600e3, 2500e3, 7000e3);
        let tof = 3600.0;
        let solution = Lambert::solve(r1, r2, tof, mu, TransferKind::ShortWay, 0).unwrap();
        let v1_mag = solution.v1.norm();
        let v2_mag = solution.v2.norm();
        assert!(v1_mag > 5000.0 && v1_mag < 10000.0);
        assert!(v2_mag > 5000.0 && v2_mag < 10000.0);
    }

    #[test]
    fn test_lambert_invalid_inputs() {
        let mu = 3.986004418e14;
        let r1 = Vector3::new(7000e3, 0.0, 0.0);
        let r2 = Vector3::new(0.0, 7000e3, 0.0);
        assert!(Lambert::solve(r1, r2, -100.0, mu, TransferKind::Auto, 0).is_err());
        assert!(Lambert::solve(r1, r2, 1000.0, 0.0, TransferKind::Auto, 0).is_err());
        assert!(Lambert::solve(r1, r2, 100.0, mu, TransferKind::Auto, 5).is_err());
    }

    #[test]
    fn test_lambert_batch_solve() {
        let mu = 3.986004418e14;
        let r: f64 = 7000e3;
        let r1 = Vector3::new(r, 0.0, 0.0);
        let r2 = Vector3::new(0.0, r, 0.0);
        let period = 2.0 * PI * (r.powi(3) / mu).sqrt();
        let tofs = vec![period / 4.5, period / 4.0, period / 3.5];
        let solutions = Lambert::solve_batch(r1, r2, &tofs, mu, TransferKind::Auto, 0).unwrap();
        assert_eq!(solutions.len(), 3);
        for solution in solutions {
            assert!(solution.v1.norm() > 1000.0);
            assert!(solution.v2.norm() > 1000.0);
            assert!(solution.a > 0.0);
        }
    }

    #[test]
    #[ignore]
    fn test_lambert_multi_revolution_basic() {
        let mu = 3.986004418e14;
        let r: f64 = 7000e3;
        let r1 = Vector3::new(r, 0.0, 0.0);
        let r2 = Vector3::new(0.0, r, 0.0);
        let period = 2.0 * PI * (r.powi(3) / mu).sqrt();
        let tof = 4.5 * period;
        let solution = Lambert::solve(r1, r2, tof, mu, TransferKind::Auto, 1).unwrap();
        assert_eq!(solution.revs, 1);
        assert!(solution.v1.norm() > 100.0);
        assert!(solution.v2.norm() > 100.0);
        assert!(solution.a > 0.0);
        assert!(solution.e >= 0.0 && solution.e < 1.0);
    }

    #[test]
    #[ignore]
    fn test_lambert_multi_revolution_two_revs() {
        let mu = 3.986004418e14;
        let r: f64 = 8000e3;
        let r1 = Vector3::new(r, 0.0, 0.0);
        let r2 = Vector3::new(0.0, r, 0.0);
        let period = 2.0 * PI * (r.powi(3) / mu).sqrt();
        let tof = 9.0 * period;
        let solution = Lambert::solve(r1, r2, tof, mu, TransferKind::Auto, 2).unwrap();
        assert_eq!(solution.revs, 2);
        assert!(solution.a > 0.0);
        assert!(solution.e >= 0.0 && solution.e < 1.0);
    }

    #[test]
    fn test_lambert_multi_revolution_too_many_revs() {
        let mu = 3.986004418e14;
        let r: f64 = 7000e3;
        let r1 = Vector3::new(r, 0.0, 0.0);
        let r2 = Vector3::new(0.0, r, 0.0);
        let tof = 1000.0;
        let result = Lambert::solve(r1, r2, tof, mu, TransferKind::Auto, 10);
        assert!(result.is_err());
    }

    #[test]
    #[ignore]
    fn test_lambert_multi_revolution_short_vs_long() {
        let mu = 3.986004418e14;
        let r: f64 = 7000e3;
        let r1 = Vector3::new(r, 0.0, 0.0);
        let r2 = Vector3::new(0.0, r, 0.0);
        let period = 2.0 * PI * (r.powi(3) / mu).sqrt();
        let tof = 5.0 * period;
        let solution_short = Lambert::solve(r1, r2, tof, mu, TransferKind::ShortWay, 1).unwrap();
        let solution_long = Lambert::solve(r1, r2, tof, mu, TransferKind::LongWay, 1).unwrap();
        assert_eq!(solution_short.revs, 1);
        assert_eq!(solution_long.revs, 1);
        assert!(solution_short.short_way);
        assert!(!solution_long.short_way);
    }

    #[test]
    fn test_lambert_helpers_time_of_flight() {
        let x = 0.5;
        let lambda = 0.7;
        let n = 1;
        let t = time_of_flight_izzo(x, lambda, n);
        assert!(t > 0.0);
        let t0 = time_of_flight_izzo(x, lambda, 0);
        assert!(t > t0);
        println!("x={}, lambda={}, n=0: t={}", x, lambda, t0);
        println!("x={}, lambda={}, n=1: t={}", x, lambda, t);
    }

    #[test]
    fn test_lambert_helpers_derivatives() {
        let x = 0.3;
        let lambda = 0.6;
        let n = 1;
        let (dt_dx, d2t_dx2, d3t_dx3) = time_derivatives_izzo(x, lambda, n);
        assert!(dt_dx.abs() > 1e-10);
        assert!(d2t_dx2.abs() > 1e-10);
        assert!(d3t_dx3.abs() > 1e-10);
    }

    #[test]
    fn test_lambert_position_magnitude_error() {
        let mu = 3.986004418e14;
        let r1 = Vector3::new(0.5, 0.0, 0.0);
        let r2 = Vector3::new(0.0, 7000e3, 0.0);
        let tof = 1000.0;
        let result = Lambert::solve(r1, r2, tof, mu, TransferKind::Auto, 0);
        assert!(result.is_err());
        let r1 = Vector3::new(7000e3, 0.0, 0.0);
        let r2 = Vector3::new(0.0, 0.5, 0.0);
        let result = Lambert::solve(r1, r2, tof, mu, TransferKind::Auto, 0);
        assert!(result.is_err());
    }

    #[test]
    fn test_lambert_opposite_vectors() {
        let mu = 3.986004418e14;
        let r1 = Vector3::new(7000e3, 0.0, 0.0);
        let r2 = Vector3::new(-7000e3, 1.0, 0.0);
        let tof = 1000.0;
        let result = Lambert::solve(r1, r2, tof, mu, TransferKind::Auto, 0);
        // NOTE(review): both branches hold unconditionally; this only checks that solve() returns.
        if result.is_err() {
            assert!(result.is_err());
        } else {
            assert!(result.is_ok());
        }
    }

    #[test]
    fn test_lambert_transfer_type() {
        let mu = 3.986004418e14;
        let r1 = Vector3::new(7000e3, 0.0, 0.0);
        let r2 = Vector3::new(0.0, 7000e3, 0.0);
        let period = 2.0 * std::f64::consts::PI * ((7000e3_f64).powi(3) / mu).sqrt();
        let tof = period * 0.25;
        let solution_short = Lambert::solve(r1, r2, tof, mu, TransferKind::ShortWay, 0).unwrap();
        assert!(solution_short.short_way, "Should be short-way transfer");
        assert!(solution_short.v1.norm() > 1000.0);
        assert!(solution_short.v2.norm() > 1000.0);
        let solution_auto = Lambert::solve(r1, r2, tof, mu, TransferKind::Auto, 0).unwrap();
        assert!(solution_auto.v1.norm() > 1000.0);
        assert!(solution_auto.v2.norm() > 1000.0);
        assert!(solution_short.a > 0.0);
    }

    #[test]
    fn test_lambert_batch_parallel() {
        let mu = 3.986004418e14;
        let r = 7000e3;
        let n = 10;
        let r1s: Vec<Vector3<f64>> = (0..n).map(|_| Vector3::new(r, 0.0, 0.0)).collect();
        let r2s: Vec<Vector3<f64>> = (0..n).map(|_| Vector3::new(0.0, r, 0.0)).collect();
        let tofs: Vec<f64> = (1..=n).map(|i| i as f64 * 600.0 + 1800.0).collect();
        let result = Lambert::solve_batch_parallel(&r1s, &r2s, &tofs, mu, TransferKind::Auto, 0);
        if let Ok(solutions) = result {
            assert_eq!(solutions.len(), tofs.len());
            for (i, sol) in solutions.iter().enumerate() {
                assert!(sol.v1.norm() > 100.0, "Solution {} v1 too small", i);
                assert!(sol.v2.norm() > 100.0, "Solution {} v2 too small", i);
            }
        } else {
            assert!(result.is_err());
        }
    }

    #[test]
    fn test_lambert_batch_empty() {
        let mu = 3.986004418e14;
        let r1 = Vector3::new(7000e3, 0.0, 0.0);
        let r2 = Vector3::new(0.0, 7000e3, 0.0);
        let tofs: Vec<f64> = vec![];
        let result = Lambert::solve_batch(r1, r2, &tofs, mu, TransferKind::Auto, 0);
        if let Ok(solutions) = result {
            assert_eq!(solutions.len(), 0, "Empty input should give empty output");
        }
        let r1s: Vec<Vector3<f64>> = vec![];
        let r2s: Vec<Vector3<f64>> = vec![];
        let result = Lambert::solve_batch_parallel(&r1s, &r2s, &tofs, mu, TransferKind::Auto, 0);
        if let Ok(solutions) = result {
            assert_eq!(solutions.len(), 0, "Empty input should give empty output");
        }
    }

    #[test]
    fn test_lambert_very_short_tof() {
        let mu = 3.986004418e14;
        let r1 = Vector3::new(7000e3, 0.0, 0.0);
        let r2 = Vector3::new(7001e3, 100e3, 0.0);
        let tof = 1.0;
        let result = Lambert::solve(r1, r2, tof, mu, TransferKind::Auto, 0);
        if result.is_ok() {
            let solution = result.unwrap();
            assert!(solution.v1.norm() > 5000.0);
        }
    }

    #[test]
    fn test_lambert_moderate_tof() {
        let mu = 3.986004418e14;
        let r1 = Vector3::new(7000e3, 0.0, 0.0);
        let r2 = Vector3::new(0.0, 7000e3, 0.0);
        let period = 2.0 * std::f64::consts::PI * ((7000e3_f64).powi(3) / mu).sqrt();
        let tof = period * 0.25;
        let solution = Lambert::solve(r1, r2, tof, mu, TransferKind::Auto, 0).unwrap();
        assert!(solution.v1.norm() > 1000.0);
        assert!(solution.v2.norm() > 1000.0);
        let v_esc = (2.0 * mu / r1.norm()).sqrt();
        assert!(solution.v1.norm() < v_esc);
        assert!(solution.a > 0.0);
    }

    #[test]
    fn test_lambert_stumpff_edge_cases() {
        use super::stumpff_functions;
        let (c2, c3) = stumpff_functions(0.0);
        assert!((c2 - 0.5).abs() < 1e-10);
        assert!((c3 - 1.0 / 6.0).abs() < 1e-10);
        let (c2_small, c3_small) = stumpff_functions(0.001);
        assert!(c2_small > 0.0);
        assert!(c3_small > 0.0);
        let (c2_neg, c3_neg) = stumpff_functions(-0.001);
        assert!(c2_neg > 0.0);
        assert!(c3_neg > 0.0, "c3 is positive for hyperbolic arguments too");
        let (c2_large, c3_large) = stumpff_functions(100.0);
        assert!(c2_large > 0.0);
        assert!(c3_large > 0.0);
        let (c2_hyp, c3_hyp) = stumpff_functions(-100.0);
        assert!(c2_hyp > 0.0);
        assert!(c3_hyp > 0.0, "c3 is positive for hyperbolic arguments too");
        // Continuity across the series / closed-form switch at |z| = 1e-6.
        let (_, c3_left) = stumpff_functions(-2e-6);
        assert!((c3_left - 1.0 / 6.0).abs() < 1e-6);
        // Closed form at z = -1: c3 = sinh(1) - 1.
        let (_, c3_m1) = stumpff_functions(-1.0);
        assert_relative_eq!(c3_m1, 1.0_f64.sinh() - 1.0, epsilon = 1e-12);
    }

    #[test]
    fn test_stumpff_derivatives_match_finite_difference() {
        let h = 1e-5;
        for &z in &[-2.0, -0.5, 0.5, 2.0, 10.0] {
            let (c2, c3) = stumpff_functions(z);
            let (c2_prime, c3_prime) = stumpff_derivatives(z, c2, c3);
            let (c2_plus, c3_plus) = stumpff_functions(z + h);
            let (c2_minus, c3_minus) = stumpff_functions(z - h);
            assert_relative_eq!(c2_prime, (c2_plus - c2_minus) / (2.0 * h), epsilon = 1e-7);
            assert_relative_eq!(c3_prime, (c3_plus - c3_minus) / (2.0 * h), epsilon = 1e-7);
        }
    }

    #[test]
    fn test_lambert_hyperbolic_transfer() {
        let mu = 3.986004418e14;
        let r1 = Vector3::new(7000e3, 0.0, 0.0);
        let r2 = Vector3::new(700000e3, 0.0, 0.0);
        let tof = 86400.0 * 10.0;
        let result = Lambert::solve(r1, r2, tof, mu, TransferKind::Auto, 0);
        if result.is_ok() {
            let solution = result.unwrap();
            let v_esc = (2.0 * mu / r1.norm()).sqrt();
            println!("Initial velocity: {}, Escape velocity: {}", solution.v1.norm(), v_esc);
            assert!(solution.v1.norm() > 0.0);
            assert!(solution.v2.norm() > 0.0);
        }
    }

    #[test]
    fn test_lambert_long_way_explicit() {
        let mu = 3.986004418e14;
        let r1 = Vector3::new(7000e3, 0.0, 0.0);
        let r2 = Vector3::new(0.0, 7000e3, 0.0);
        let period = 2.0 * std::f64::consts::PI * ((7000e3_f64).powi(3) / mu).sqrt();
        let tof = period * 0.75;
        let solution = Lambert::solve(r1, r2, tof, mu, TransferKind::LongWay, 0);
        match solution {
            Ok(sol) => {
                assert!(!sol.short_way, "Should be long-way transfer");
                assert!(sol.v1.norm() > 1000.0);
                assert!(sol.v2.norm() > 1000.0);
            }
            Err(_) => {}
        }
    }

    #[test]
    fn test_lambert_nearly_opposite_vectors() {
        let mu = 3.986004418e14;
        let r1 = Vector3::new(7000e3, 0.0, 0.0);
        let r2 = Vector3::new(-7000e3, 1.0, 0.0);
        let tof = 3600.0;
        let result = Lambert::solve(r1, r2, tof, mu, TransferKind::Auto, 0);
        match result {
            Err(e) => {
                assert!(e.to_string().contains("opposite") || e.to_string().contains("unique"));
            }
            Ok(_) => {}
        }
    }

    #[test]
    fn test_lambert_perfectly_opposite_vectors() {
        let mu = 3.986004418e14;
        let r1 = Vector3::new(7000e3, 0.0, 0.0);
        let r2 = Vector3::new(-7000e3, 0.0, 0.0);
        let tof = 3600.0;
        let result = Lambert::solve(r1, r2, tof, mu, TransferKind::Auto, 0);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.to_string().contains("opposite") || err.to_string().contains("unique"));
    }

    #[test]
    fn test_lambert_izzo_multirev_high_revs() {
        let mu = 3.986004418e14;
        let r1 = Vector3::new(7000e3, 0.0, 0.0);
        let r2 = Vector3::new(0.0, 7000e3, 0.0);
        let period = 2.0 * std::f64::consts::PI * ((7000e3_f64).powi(3) / mu).sqrt();
        let tof = period * 2.5;
        let result = Lambert::solve(r1, r2, tof, mu, TransferKind::Auto, 2);
        match result {
            Ok(solution) => {
                assert!(solution.v1.norm() > 0.0);
                assert!(solution.v2.norm() > 0.0);
            }
            Err(_) => {}
        }
    }

    #[test]
    fn test_lambert_negative_y_adjustment() {
        let mu = 3.986004418e14;
        let r1 = Vector3::new(7000e3, 0.0, 0.0);
        let r2 = Vector3::new(5000e3, 5000e3, 0.0);
        let tof = 100.0;
        let result = Lambert::solve(r1, r2, tof, mu, TransferKind::Auto, 0);
        match result {
            Ok(solution) => {
                assert!(solution.v1.norm() > 0.0);
                assert!(solution.v2.norm() > 0.0);
            }
            Err(_) => {}
        }
    }

    #[test]
    fn test_lambert_short_way_branch() {
        let mu = 3.986004418e14;
        let r1 = Vector3::new(7000e3, 0.0, 0.0);
        let r2 = Vector3::new(0.0, 7000e3, 0.0);
        let period = 2.0 * std::f64::consts::PI * ((7000e3_f64).powi(3) / mu).sqrt();
        let tof = period * 0.25;
        let solution = Lambert::solve(r1, r2, tof, mu, TransferKind::ShortWay, 0).unwrap();
        assert!(solution.short_way);
        assert_eq!(solution.revs, 0);
        assert!(solution.v1.norm() > 1000.0);
        assert!(solution.v2.norm() > 1000.0);
    }

    #[test]
    fn test_lambert_exceeding_max_revs() {
        let mu = 3.986004418e14;
        let r1 = Vector3::new(7000e3, 0.0, 0.0);
        let r2 = Vector3::new(0.0, 7000e3, 0.0);
        let period = 2.0 * std::f64::consts::PI * ((7000e3_f64).powi(3) / mu).sqrt();
        let tof = period * 0.3;
        let result = Lambert::solve(r1, r2, tof, mu, TransferKind::Auto, 10);
        assert!(result.is_err());
        let err = result.unwrap_err();
        let err_str = err.to_string();
        assert!(err_str.contains("exceeds maximum") || err_str.contains("revs"));
    }
}