use nalgebra::{DMatrix, DVector};
use serde_json::Value;
use std::collections::HashMap;
use crate::math::{
CovarianceInterpolationConfig, interpolate_covariance_sqrt_dmatrix,
interpolate_covariance_two_wasserstein_dmatrix, interpolate_lagrange_dvector,
};
use crate::time::Epoch;
use crate::utils::BraheError;
use super::traits::{
CovarianceInterpolationMethod, InterpolatableTrajectory, InterpolationConfig,
InterpolationMethod, STMStorage, SensitivityStorage, Trajectory, TrajectoryEvictionPolicy,
};
/// A time-ordered trajectory of dynamically sized state vectors (`DVector<f64>`).
///
/// `epochs` and `states` are parallel vectors. The optional covariance, STM and
/// sensitivity vectors, when present, hold one matrix per sample in the same order.
/// The insertion methods keep samples sorted by epoch and apply the configured
/// eviction policy after each insertion.
#[derive(Debug, Clone, PartialEq)]
pub struct DTrajectory {
/// Sample epochs, kept in ascending order by the insertion methods and `from_data`.
pub epochs: Vec<Epoch>,
/// State vectors, one per entry in `epochs`; each is expected to have `dimension` elements.
pub states: Vec<DVector<f64>>,
/// Optional per-sample covariance matrices (`dimension × dimension`); `None` until
/// covariance storage is enabled. Samples inserted without a covariance get a zero matrix.
pub covariances: Option<Vec<DMatrix<f64>>>,
/// Optional per-sample state transition matrices (`dimension × dimension`); samples
/// inserted without an STM get an identity matrix.
pub stms: Option<Vec<DMatrix<f64>>>,
/// Optional per-sample sensitivity matrices (`dimension` rows × parameter-count columns);
/// samples inserted without a sensitivity get a zero matrix.
pub sensitivities: Option<Vec<DMatrix<f64>>>,
/// `(rows, cols)` shared by every sensitivity matrix, fixed when sensitivity storage is created.
sensitivity_dimension: Option<(usize, usize)>,
/// Number of elements in each state vector (`new` and `from_data` reject 0).
pub dimension: usize,
/// Method used when interpolating states.
pub interpolation_method: InterpolationMethod,
/// Method used by `covariance_at` to blend covariances between samples.
pub covariance_interpolation_method: CovarianceInterpolationMethod,
/// Policy deciding which samples are dropped after an insertion.
pub eviction_policy: TrajectoryEvictionPolicy,
/// Maximum sample count, used by the `KeepCount` policy.
max_size: Option<usize>,
/// Maximum age relative to the latest epoch, used by the `KeepWithinDuration` policy.
/// NOTE(review): presumably seconds (the tests use 86400.0 per day) — confirm against `Epoch`.
max_age: Option<f64>,
/// Free-form JSON metadata; none of the trajectory operations shown here read it.
pub metadata: HashMap<String, Value>,
}
impl DTrajectory {
    /// Creates an empty trajectory whose state vectors have `dimension` elements.
    ///
    /// Defaults: linear state interpolation, two-Wasserstein covariance interpolation,
    /// no eviction policy, and no covariance/STM/sensitivity storage.
    ///
    /// # Panics
    /// Panics if `dimension` is 0.
    pub fn new(dimension: usize) -> Self {
        if dimension == 0 {
            panic!("Trajectory dimension must be greater than 0");
        }
        Self {
            epochs: Vec::new(),
            states: Vec::new(),
            covariances: None,
            stms: None,
            sensitivities: None,
            sensitivity_dimension: None,
            dimension,
            interpolation_method: InterpolationMethod::Linear,
            covariance_interpolation_method: CovarianceInterpolationMethod::TwoWasserstein,
            eviction_policy: TrajectoryEvictionPolicy::None,
            max_size: None,
            max_age: None,
            metadata: HashMap::new(),
        }
    }

    /// Builder: sets the state interpolation method.
    pub fn with_interpolation_method(mut self, interpolation_method: InterpolationMethod) -> Self {
        self.interpolation_method = interpolation_method;
        self
    }

    /// Builder: keep only the `max_size` most recent samples.
    ///
    /// Replaces any age-based limit. The policy is applied immediately, matching
    /// `set_eviction_policy_max_size`. Without this, an already-populated trajectory
    /// (e.g. one built with `from_data`) would exceed the limit until its next insertion.
    ///
    /// # Panics
    /// Panics if `max_size` is 0.
    pub fn with_eviction_policy_max_size(mut self, max_size: usize) -> Self {
        if max_size < 1 {
            panic!("Maximum size must be >= 1");
        }
        self.eviction_policy = TrajectoryEvictionPolicy::KeepCount;
        self.max_size = Some(max_size);
        self.max_age = None;
        self.apply_eviction_policy();
        self
    }

    /// Builder: keep only samples within `max_age` of the latest epoch.
    ///
    /// Replaces any count-based limit and applies the policy immediately, matching
    /// `set_eviction_policy_max_age`.
    ///
    /// # Panics
    /// Panics if `max_age` is not strictly positive.
    pub fn with_eviction_policy_max_age(mut self, max_age: f64) -> Self {
        if max_age <= 0.0 {
            panic!("Maximum age must be > 0.0");
        }
        self.eviction_policy = TrajectoryEvictionPolicy::KeepWithinDuration;
        self.max_age = Some(max_age);
        self.max_size = None;
        self.apply_eviction_policy();
        self
    }

    /// Number of elements in each state vector.
    pub fn dimension(&self) -> usize {
        self.dimension
    }

    /// Stacks the states into an `n_samples × dimension` matrix, one row per sample.
    ///
    /// # Errors
    /// Returns an error if the trajectory is empty.
    pub fn to_matrix(&self) -> Result<DMatrix<f64>, BraheError> {
        if self.states.is_empty() {
            return Err(BraheError::Error(
                "Cannot convert empty trajectory to matrix".to_string(),
            ));
        }
        Ok(DMatrix::from_fn(
            self.states.len(),
            self.dimension,
            |row, col| self.states[row][col],
        ))
    }

    /// Allocates covariance storage, back-filling every existing sample with a zero matrix.
    ///
    /// Does nothing if covariance storage already exists.
    pub fn enable_covariance_storage(&mut self) {
        if self.covariances.is_none() {
            let zero_cov = DMatrix::zeros(self.dimension, self.dimension);
            self.covariances = Some(vec![zero_cov; self.states.len()]);
        }
    }

    /// Inserts a sample together with its covariance, enabling covariance storage if needed.
    ///
    /// The sample goes after any existing samples with an equal or earlier epoch. If STM or
    /// sensitivity storage is enabled, an identity STM / zero sensitivity placeholder is
    /// inserted at the same position so that all per-sample vectors stay aligned. The
    /// eviction policy is then applied.
    ///
    /// # Panics
    /// Panics if `state` does not have `dimension` elements or `covariance` is not
    /// `dimension × dimension`.
    pub fn add_with_covariance(
        &mut self,
        epoch: Epoch,
        state: DVector<f64>,
        covariance: DMatrix<f64>,
    ) {
        if state.len() != self.dimension {
            panic!("State vector dimension does not match trajectory dimension.");
        }
        if covariance.nrows() != self.dimension || covariance.ncols() != self.dimension {
            panic!(
                "Covariance matrix dimensions {}x{} do not match trajectory dimension {}",
                covariance.nrows(),
                covariance.ncols(),
                self.dimension
            );
        }
        self.enable_covariance_storage();

        // First position whose epoch is strictly later; equal epochs keep insertion order.
        let insert_idx = self
            .epochs
            .iter()
            .position(|existing| epoch < *existing)
            .unwrap_or(self.epochs.len());

        self.epochs.insert(insert_idx, epoch);
        self.states.insert(insert_idx, state);
        if let Some(ref mut covs) = self.covariances {
            covs.insert(insert_idx, covariance);
        }
        // Keep the optional STM / sensitivity histories the same length as `states`;
        // otherwise later removals and evictions would index them out of step.
        if let Some(ref mut stms) = self.stms {
            stms.insert(
                insert_idx,
                DMatrix::identity(self.dimension, self.dimension),
            );
        }
        if let Some(ref mut sens) = self.sensitivities
            && let Some((rows, cols)) = self.sensitivity_dimension
        {
            sens.insert(insert_idx, DMatrix::zeros(rows, cols));
        }
        self.apply_eviction_policy();
    }

    /// Overwrites the covariance of the sample at `index`, enabling covariance storage
    /// (zero-filled) if needed.
    ///
    /// # Panics
    /// Panics if `index` is out of bounds or `covariance` is not `dimension × dimension`.
    pub fn set_covariance_at(&mut self, index: usize, covariance: DMatrix<f64>) {
        if index >= self.states.len() {
            panic!(
                "Index {} out of bounds for trajectory with {} states",
                index,
                self.states.len()
            );
        }
        if covariance.nrows() != self.dimension || covariance.ncols() != self.dimension {
            panic!(
                "Covariance matrix dimensions {}x{} do not match trajectory dimension {}",
                covariance.nrows(),
                covariance.ncols(),
                self.dimension
            );
        }
        self.enable_covariance_storage();
        if let Some(ref mut covs) = self.covariances {
            covs[index] = covariance;
        }
    }

    /// Returns the covariance at `epoch`.
    ///
    /// An exact epoch match returns the stored matrix. An epoch strictly between two samples
    /// is blended from its neighbours with the configured covariance interpolation method.
    /// Returns `None` when covariance storage is disabled, the trajectory is empty, or
    /// `epoch` lies outside the stored span.
    pub fn covariance_at(&self, epoch: Epoch) -> Option<DMatrix<f64>> {
        let covs = self.covariances.as_ref()?;
        if let Some(idx) = self.epochs.iter().position(|e| *e == epoch) {
            return Some(covs[idx].clone());
        }
        let (idx_before, idx_after) = self.find_surrounding_indices(epoch)?;

        // Exact matches were handled above, so the bracketing epochs differ and the
        // denominator below is non-zero.
        let origin = self.epoch_initial()?;
        let t0 = self.epochs[idx_before] - origin;
        let t1 = self.epochs[idx_after] - origin;
        let t = epoch - origin;
        let alpha = (t - t0) / (t1 - t0);

        let cov0 = &covs[idx_before];
        let cov1 = &covs[idx_after];
        let cov = match self.covariance_interpolation_method {
            CovarianceInterpolationMethod::MatrixSquareRoot => {
                interpolate_covariance_sqrt_dmatrix(cov0, cov1, alpha)
            }
            CovarianceInterpolationMethod::TwoWasserstein => {
                interpolate_covariance_two_wasserstein_dmatrix(cov0, cov1, alpha)
            }
        };
        Some(cov)
    }

    /// Epoch of the first stored sample, if any.
    fn epoch_initial(&self) -> Option<Epoch> {
        self.epochs.first().copied()
    }

    /// Indices `(i, i + 1)` of the first adjacent pair with `epochs[i] <= epoch <= epochs[i + 1]`.
    ///
    /// Returns `None` if the trajectory has fewer than two samples or `epoch` is outside
    /// `[first, last]`.
    fn find_surrounding_indices(&self, epoch: Epoch) -> Option<(usize, usize)> {
        let first = *self.epochs.first()?;
        let last = *self.epochs.last()?;
        if epoch < first || epoch > last {
            return None;
        }
        self.epochs
            .windows(2)
            .position(|pair| pair[0] <= epoch && epoch <= pair[1])
            .map(|i| (i, i + 1))
    }

    /// Drops samples (and their per-sample matrices) according to the eviction policy.
    ///
    /// - `KeepCount`: removes the earliest samples beyond `max_size`.
    /// - `KeepWithinDuration`: keeps only samples whose distance from the last stored
    ///   epoch is at most `max_age`.
    fn apply_eviction_policy(&mut self) {
        /// Keeps `items[i]` iff `keep[i]`. Items beyond `keep.len()` are dropped.
        /// Relies on `Vec::retain` visiting elements exactly once, in order.
        fn retain_by_mask<T>(items: &mut Vec<T>, keep: &[bool]) {
            let mut flags = keep.iter().copied();
            items.retain(|_| flags.next().unwrap_or(false));
        }

        match self.eviction_policy {
            TrajectoryEvictionPolicy::None => {}
            TrajectoryEvictionPolicy::KeepCount => {
                if let Some(max_size) = self.max_size
                    && self.epochs.len() > max_size
                {
                    let to_remove = self.epochs.len() - max_size;
                    self.epochs.drain(0..to_remove);
                    self.states.drain(0..to_remove);
                    if let Some(ref mut covs) = self.covariances {
                        covs.drain(0..to_remove);
                    }
                    if let Some(ref mut stms) = self.stms {
                        stms.drain(0..to_remove);
                    }
                    if let Some(ref mut sens) = self.sensitivities {
                        sens.drain(0..to_remove);
                    }
                }
            }
            TrajectoryEvictionPolicy::KeepWithinDuration => {
                if let Some(max_age) = self.max_age
                    && let Some(&last_epoch) = self.epochs.last()
                {
                    // Filter in place rather than cloning every state and matrix into
                    // fresh vectors, which this method previously did on every insertion.
                    let keep: Vec<bool> = self
                        .epochs
                        .iter()
                        .map(|&e| (last_epoch - e).abs() <= max_age)
                        .collect();
                    retain_by_mask(&mut self.epochs, &keep);
                    retain_by_mask(&mut self.states, &keep);
                    if let Some(ref mut covs) = self.covariances {
                        retain_by_mask(covs, &keep);
                    }
                    if let Some(ref mut stms) = self.stms {
                        retain_by_mask(stms, &keep);
                    }
                    if let Some(ref mut sens) = self.sensitivities {
                        retain_by_mask(sens, &keep);
                    }
                }
            }
        }
    }
}
impl Default for DTrajectory {
fn default() -> Self {
Self::new(6)
}
}
impl DTrajectory {
pub fn add_full(
&mut self,
epoch: Epoch,
state: DVector<f64>,
covariance: Option<DMatrix<f64>>,
stm: Option<DMatrix<f64>>,
sensitivity: Option<DMatrix<f64>>,
) {
if state.len() != self.dimension {
panic!(
"State vector dimension {} does not match trajectory dimension {}",
state.len(),
self.dimension
);
}
if let Some(ref cov) = covariance {
if cov.nrows() != self.dimension || cov.ncols() != self.dimension {
panic!(
"Covariance dimensions {}×{} do not match expected {}×{}",
cov.nrows(),
cov.ncols(),
self.dimension,
self.dimension
);
}
if self.covariances.is_none() {
self.covariances = Some(vec![
DMatrix::zeros(self.dimension, self.dimension);
self.states.len()
]);
}
}
if let Some(ref stm_val) = stm {
if stm_val.nrows() != self.dimension || stm_val.ncols() != self.dimension {
panic!(
"STM dimensions {}×{} do not match expected {}×{}",
stm_val.nrows(),
stm_val.ncols(),
self.dimension,
self.dimension
);
}
if self.stms.is_none() {
let identity = DMatrix::identity(self.dimension, self.dimension);
self.stms = Some(vec![identity; self.states.len()]);
}
}
if let Some(ref sens) = sensitivity {
if sens.nrows() != self.dimension {
panic!(
"Sensitivity row count {} does not match state dimension {}",
sens.nrows(),
self.dimension
);
}
if let Some((_, existing_cols)) = self.sensitivity_dimension {
if sens.ncols() != existing_cols {
panic!(
"Sensitivity column count {} does not match existing {}",
sens.ncols(),
existing_cols
);
}
} else if self.sensitivities.is_none() {
let zero_sens = DMatrix::zeros(self.dimension, sens.ncols());
self.sensitivities = Some(vec![zero_sens; self.states.len()]);
self.sensitivity_dimension = Some((self.dimension, sens.ncols()));
}
}
let mut insert_idx = self.epochs.len();
for (i, existing_epoch) in self.epochs.iter().enumerate() {
if epoch < *existing_epoch {
insert_idx = i;
break;
}
}
self.epochs.insert(insert_idx, epoch);
self.states.insert(insert_idx, state);
if let Some(ref mut covs) = self.covariances {
if let Some(cov) = covariance {
covs.insert(insert_idx, cov);
} else {
covs.insert(insert_idx, DMatrix::zeros(self.dimension, self.dimension));
}
}
if let Some(ref mut stms) = self.stms {
if let Some(stm_val) = stm {
stms.insert(insert_idx, stm_val);
} else {
stms.insert(
insert_idx,
DMatrix::identity(self.dimension, self.dimension),
);
}
}
if let Some(ref mut sens) = self.sensitivities {
if let Some(sens_val) = sensitivity {
sens.insert(insert_idx, sens_val);
} else if let Some((rows, cols)) = self.sensitivity_dimension {
sens.insert(insert_idx, DMatrix::zeros(rows, cols));
}
}
self.apply_eviction_policy();
}
}
impl std::ops::Index<usize> for DTrajectory {
    type Output = DVector<f64>;

    /// Borrows the state stored at position `index` (epoch order).
    ///
    /// # Panics
    /// Panics if `index` is out of bounds, exactly as slice indexing does.
    fn index(&self, index: usize) -> &Self::Output {
        let states: &[DVector<f64>] = &self.states;
        &states[index]
    }
}
/// Iterator over a borrowed [`DTrajectory`] that yields `(epoch, state)` pairs in stored order.
///
/// Items are produced through `Trajectory::get`, so each state is cloned, not borrowed.
pub struct DTrajectoryIterator<'a> {
/// Trajectory being iterated.
trajectory: &'a DTrajectory,
/// Position of the next sample to yield.
index: usize,
}
impl<'a> Iterator for DTrajectoryIterator<'a> {
    type Item = (Epoch, DVector<f64>);

    /// Yields the next `(epoch, state)` pair, or `None` once every sample has been visited.
    fn next(&mut self) -> Option<Self::Item> {
        if self.index >= self.trajectory.len() {
            return None;
        }
        let current = self.index;
        self.index = current + 1;
        self.trajectory.get(current).ok()
    }

    /// Exact bound: the number of samples not yet yielded.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let left = self.trajectory.len() - self.index;
        (left, Some(left))
    }
}
impl ExactSizeIterator for DTrajectoryIterator<'_> {
    /// Number of samples remaining to be yielded.
    fn len(&self) -> usize {
        let total = self.trajectory.len();
        total - self.index
    }
}
impl<'a> IntoIterator for &'a DTrajectory {
    type Item = (Epoch, DVector<f64>);
    type IntoIter = DTrajectoryIterator<'a>;

    /// Starts iteration at the first (earliest) stored sample.
    fn into_iter(self) -> Self::IntoIter {
        let trajectory = self;
        DTrajectoryIterator {
            index: 0,
            trajectory,
        }
    }
}
impl Trajectory for DTrajectory {
    type StateVector = DVector<f64>;

    /// Builds a trajectory from parallel `epochs` / `states` vectors.
    ///
    /// Samples are sorted by epoch with a stable sort, so equal epochs keep their input
    /// order. The dimension is taken from the first state. All other settings match
    /// `DTrajectory::new`.
    ///
    /// # Errors
    /// Returns an error if the vectors differ in length, are empty, the first state is
    /// empty, or any state's length differs from the first.
    ///
    /// # Panics
    /// Panics if two epochs are incomparable (`partial_cmp` returns `None`).
    fn from_data(epochs: Vec<Epoch>, states: Vec<Self::StateVector>) -> Result<Self, BraheError> {
        if epochs.len() != states.len() {
            return Err(BraheError::Error(
                "Epochs and states vectors must have the same length".to_string(),
            ));
        }
        if epochs.is_empty() {
            return Err(BraheError::Error(
                "Cannot create trajectory from empty data".to_string(),
            ));
        }
        let dimension = states[0].len();
        if dimension == 0 {
            return Err(BraheError::Error(
                "State vectors cannot be empty".to_string(),
            ));
        }
        if let Some((i, state)) = states
            .iter()
            .enumerate()
            .find(|(_, state)| state.len() != dimension)
        {
            return Err(BraheError::Error(format!(
                "State {} has dimension {} but expected {}",
                i,
                state.len(),
                dimension
            )));
        }

        // Sort owned (epoch, state) pairs instead of cloning every state through an index
        // permutation; `sort_by` is stable, so the resulting order is unchanged.
        let mut samples: Vec<(Epoch, DVector<f64>)> = epochs.into_iter().zip(states).collect();
        samples.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
        let (sorted_epochs, sorted_states): (Vec<Epoch>, Vec<DVector<f64>>) =
            samples.into_iter().unzip();

        Ok(Self {
            epochs: sorted_epochs,
            states: sorted_states,
            ..Self::new(dimension)
        })
    }

    /// Inserts a sample after any existing samples with an equal or earlier epoch.
    ///
    /// Enabled covariance / STM / sensitivity storage receives a zero / identity / zero
    /// placeholder at the same position. The eviction policy is then applied.
    ///
    /// # Panics
    /// Panics if `state` does not have `dimension` elements.
    fn add(&mut self, epoch: Epoch, state: DVector<f64>) {
        if state.len() != self.dimension {
            panic!("State vector dimension does not match trajectory dimension.");
        }
        let insert_idx = self
            .epochs
            .iter()
            .position(|existing| epoch < *existing)
            .unwrap_or(self.epochs.len());
        self.epochs.insert(insert_idx, epoch);
        self.states.insert(insert_idx, state);
        if let Some(ref mut covs) = self.covariances {
            covs.insert(insert_idx, DMatrix::zeros(self.dimension, self.dimension));
        }
        if let Some(ref mut stms) = self.stms {
            stms.insert(
                insert_idx,
                DMatrix::identity(self.dimension, self.dimension),
            );
        }
        if let Some(ref mut sens) = self.sensitivities
            && let Some((rows, cols)) = self.sensitivity_dimension
        {
            sens.insert(insert_idx, DMatrix::zeros(rows, cols));
        }
        self.apply_eviction_policy();
    }

    /// Epoch of the sample at `index`.
    ///
    /// # Errors
    /// Returns an error if `index` is out of bounds.
    fn epoch_at_idx(&self, index: usize) -> Result<Epoch, BraheError> {
        self.epochs.get(index).copied().ok_or_else(|| {
            BraheError::Error(format!(
                "Index {} out of bounds for trajectory with {} epochs",
                index,
                self.epochs.len()
            ))
        })
    }

    /// Clone of the state at `index`.
    ///
    /// # Errors
    /// Returns an error if `index` is out of bounds.
    fn state_at_idx(&self, index: usize) -> Result<DVector<f64>, BraheError> {
        self.states.get(index).cloned().ok_or_else(|| {
            BraheError::Error(format!(
                "Index {} out of bounds for trajectory with {} states",
                index,
                self.states.len()
            ))
        })
    }

    /// Sample whose epoch is closest to `epoch`; ties resolve to the earlier sample.
    ///
    /// The scan stops early once it is past `epoch` and distances start growing, which
    /// relies on epochs being sorted.
    ///
    /// # Errors
    /// Returns an error if the trajectory is empty.
    fn nearest_state(&self, epoch: &Epoch) -> Result<(Epoch, DVector<f64>), BraheError> {
        if self.epochs.is_empty() {
            return Err(BraheError::Error(
                "Cannot find nearest state in empty trajectory".to_string(),
            ));
        }
        let mut nearest_idx = 0;
        let mut min_diff = f64::MAX;
        for (i, existing_epoch) in self.epochs.iter().enumerate() {
            let diff = (*epoch - *existing_epoch).abs();
            if diff < min_diff {
                min_diff = diff;
                nearest_idx = i;
            }
            if i > 0 && existing_epoch > epoch && diff > min_diff {
                break;
            }
        }
        Ok((self.epochs[nearest_idx], self.states[nearest_idx].clone()))
    }

    /// Number of stored samples.
    fn len(&self) -> usize {
        self.states.len()
    }

    /// Earliest stored epoch, if any.
    fn start_epoch(&self) -> Option<Epoch> {
        self.epochs.first().copied()
    }

    /// Latest stored epoch, if any.
    fn end_epoch(&self) -> Option<Epoch> {
        self.epochs.last().copied()
    }

    /// Difference between the last and first epochs; `None` with fewer than two samples.
    fn timespan(&self) -> Option<f64> {
        if self.epochs.len() < 2 {
            return None;
        }
        Some(*self.epochs.last()? - *self.epochs.first()?)
    }

    /// Earliest sample, if any.
    fn first(&self) -> Option<(Epoch, DVector<f64>)> {
        let epoch = *self.epochs.first()?;
        Some((epoch, self.states[0].clone()))
    }

    /// Latest sample, if any.
    fn last(&self) -> Option<(Epoch, DVector<f64>)> {
        let idx = self.epochs.len().checked_sub(1)?;
        Some((self.epochs[idx], self.states[idx].clone()))
    }

    /// Removes every sample. Enabled matrix storage stays enabled, just emptied.
    fn clear(&mut self) {
        self.epochs.clear();
        self.states.clear();
        if let Some(ref mut covs) = self.covariances {
            covs.clear();
        }
        if let Some(ref mut stms) = self.stms {
            stms.clear();
        }
        if let Some(ref mut sens) = self.sensitivities {
            sens.clear();
        }
    }

    /// Removes the first sample whose epoch equals `epoch` and returns its state.
    ///
    /// # Errors
    /// Returns an error if no sample has that epoch.
    fn remove_epoch(&mut self, epoch: &Epoch) -> Result<DVector<f64>, BraheError> {
        let Some(index) = self.epochs.iter().position(|e| e == epoch) else {
            return Err(BraheError::Error(
                "Epoch not found in trajectory".to_string(),
            ));
        };
        self.remove(index).map(|(_, state)| state)
    }

    /// Removes the sample at `index` (and its per-sample matrices) and returns it.
    ///
    /// # Errors
    /// Returns an error if `index` is out of bounds.
    fn remove(&mut self, index: usize) -> Result<(Epoch, DVector<f64>), BraheError> {
        if index >= self.states.len() {
            return Err(BraheError::Error(format!(
                "Index {} out of bounds for trajectory with {} states",
                index,
                self.states.len()
            )));
        }
        let removed_epoch = self.epochs.remove(index);
        let removed_state = self.states.remove(index);
        if let Some(ref mut covs) = self.covariances {
            covs.remove(index);
        }
        if let Some(ref mut stms) = self.stms {
            stms.remove(index);
        }
        if let Some(ref mut sens) = self.sensitivities {
            sens.remove(index);
        }
        Ok((removed_epoch, removed_state))
    }

    /// Clone of the sample at `index`.
    ///
    /// # Errors
    /// Returns an error if `index` is out of bounds.
    fn get(&self, index: usize) -> Result<(Epoch, DVector<f64>), BraheError> {
        if index >= self.states.len() {
            return Err(BraheError::Error(format!(
                "Index {} out of bounds for trajectory with {} states",
                index,
                self.states.len()
            )));
        }
        Ok((self.epochs[index], self.states[index].clone()))
    }

    /// Index of the last sample whose epoch is `<= epoch`.
    ///
    /// # Errors
    /// Returns an error if the trajectory is empty or `epoch` precedes every sample.
    fn index_before_epoch(&self, epoch: &Epoch) -> Result<usize, BraheError> {
        if self.epochs.is_empty() {
            return Err(BraheError::Error(
                "Cannot get index from empty trajectory".to_string(),
            ));
        }
        if epoch < &self.epochs[0] {
            return Err(BraheError::Error(
                "Epoch is before all states in trajectory".to_string(),
            ));
        }
        self.epochs
            .iter()
            .rposition(|e| e <= epoch)
            .ok_or_else(|| BraheError::Error("Failed to find index before epoch".to_string()))
    }

    /// Index of the first sample whose epoch is `>= epoch`.
    ///
    /// # Errors
    /// Returns an error if the trajectory is empty or `epoch` follows every sample.
    fn index_after_epoch(&self, epoch: &Epoch) -> Result<usize, BraheError> {
        if self.epochs.is_empty() {
            return Err(BraheError::Error(
                "Cannot get index from empty trajectory".to_string(),
            ));
        }
        if epoch > self.epochs.last().unwrap() {
            return Err(BraheError::Error(
                "Epoch is after all states in trajectory".to_string(),
            ));
        }
        self.epochs
            .iter()
            .position(|e| e >= epoch)
            .ok_or_else(|| BraheError::Error("Failed to find index after epoch".to_string()))
    }

    /// Switches to count-based eviction and trims immediately.
    ///
    /// # Errors
    /// Returns an error if `max_size` is 0.
    fn set_eviction_policy_max_size(&mut self, max_size: usize) -> Result<(), BraheError> {
        if max_size < 1 {
            return Err(BraheError::Error("Maximum size must be >= 1".to_string()));
        }
        self.eviction_policy = TrajectoryEvictionPolicy::KeepCount;
        self.max_size = Some(max_size);
        self.max_age = None;
        self.apply_eviction_policy();
        Ok(())
    }

    /// Switches to age-based eviction and trims immediately.
    ///
    /// # Errors
    /// Returns an error if `max_age` is not strictly positive.
    fn set_eviction_policy_max_age(&mut self, max_age: f64) -> Result<(), BraheError> {
        if max_age <= 0.0 {
            return Err(BraheError::Error("Maximum age must be > 0.0".to_string()));
        }
        self.eviction_policy = TrajectoryEvictionPolicy::KeepWithinDuration;
        self.max_age = Some(max_age);
        self.max_size = None;
        self.apply_eviction_policy();
        Ok(())
    }

    /// Currently configured eviction policy.
    fn get_eviction_policy(&self) -> TrajectoryEvictionPolicy {
        self.eviction_policy
    }
}
impl InterpolationConfig for DTrajectory {
    /// Builder form of `set_interpolation_method`.
    fn with_interpolation_method(mut self, method: InterpolationMethod) -> Self {
        self.set_interpolation_method(method);
        self
    }

    /// Replaces the state interpolation method.
    fn set_interpolation_method(&mut self, method: InterpolationMethod) {
        self.interpolation_method = method;
    }

    /// Currently configured state interpolation method.
    fn get_interpolation_method(&self) -> InterpolationMethod {
        let method = self.interpolation_method;
        method
    }
}
impl CovarianceInterpolationConfig for DTrajectory {
    /// Builder form of `set_covariance_interpolation_method`.
    fn with_covariance_interpolation_method(
        mut self,
        method: CovarianceInterpolationMethod,
    ) -> Self {
        self.set_covariance_interpolation_method(method);
        self
    }

    /// Replaces the method `covariance_at` uses to blend neighbouring covariances.
    fn set_covariance_interpolation_method(&mut self, method: CovarianceInterpolationMethod) {
        self.covariance_interpolation_method = method;
    }

    /// Currently configured covariance interpolation method.
    fn get_covariance_interpolation_method(&self) -> CovarianceInterpolationMethod {
        let method = self.covariance_interpolation_method;
        method
    }
}
impl STMStorage for DTrajectory {
    /// Allocates STM storage, seeding every existing sample with an identity matrix.
    /// Does nothing when STM storage already exists.
    fn enable_stm_storage(&mut self) {
        if self.stms.is_some() {
            return;
        }
        let n = self.dimension;
        let seeded = vec![DMatrix::identity(n, n); self.states.len()];
        self.stms = Some(seeded);
    }

    /// STM of the sample at `index`; `None` if storage is disabled or `index` is out of range.
    fn stm_at_idx(&self, index: usize) -> Option<&DMatrix<f64>> {
        self.stms.as_deref().and_then(|all| all.get(index))
    }

    /// Overwrites the STM of the sample at `index`, enabling STM storage if needed.
    ///
    /// # Panics
    /// Panics if `index` is out of bounds or `stm` is not `dimension × dimension`.
    fn set_stm_at(&mut self, index: usize, stm: DMatrix<f64>) {
        let count = self.states.len();
        if index >= count {
            panic!(
                "Index {} out of bounds for trajectory with {} states",
                index, count
            );
        }
        let n = self.dimension;
        if stm.nrows() != n || stm.ncols() != n {
            panic!(
                "STM dimensions {}×{} do not match expected {}×{}",
                stm.nrows(),
                stm.ncols(),
                n,
                n
            );
        }
        self.enable_stm_storage();
        if let Some(all) = self.stms.as_mut() {
            all[index] = stm;
        }
    }

    /// Shape of every stored STM: `(dimension, dimension)`.
    fn stm_dimensions(&self) -> (usize, usize) {
        let n = self.dimension;
        (n, n)
    }

    /// Read access to the STM vector, if enabled.
    fn stm_storage(&self) -> Option<&Vec<DMatrix<f64>>> {
        match &self.stms {
            Some(all) => Some(all),
            None => None,
        }
    }

    /// Mutable access to the STM vector, if enabled.
    fn stm_storage_mut(&mut self) -> Option<&mut Vec<DMatrix<f64>>> {
        match &mut self.stms {
            Some(all) => Some(all),
            None => None,
        }
    }
}
impl SensitivityStorage for DTrajectory {
    /// Allocates sensitivity storage of shape `dimension × param_dim`.
    ///
    /// Every existing sample is seeded with a zero matrix. Does nothing (beyond the
    /// `param_dim` check) when storage already exists.
    ///
    /// # Panics
    /// Panics if `param_dim` is 0.
    fn enable_sensitivity_storage(&mut self, param_dim: usize) {
        if param_dim == 0 {
            panic!("Parameter dimension must be > 0");
        }
        if self.sensitivities.is_some() {
            return;
        }
        let rows = self.dimension;
        self.sensitivities = Some(vec![DMatrix::zeros(rows, param_dim); self.states.len()]);
        self.sensitivity_dimension = Some((rows, param_dim));
    }

    /// Sensitivity of the sample at `index`; `None` if storage is disabled or out of range.
    fn sensitivity_at_idx(&self, index: usize) -> Option<&DMatrix<f64>> {
        self.sensitivities.as_deref().and_then(|all| all.get(index))
    }

    /// Overwrites the sensitivity of the sample at `index`, enabling storage (sized from
    /// `sensitivity`'s column count) if needed.
    ///
    /// # Panics
    /// Panics if `index` is out of bounds, the row count differs from `dimension`, the
    /// column count differs from the established one, or storage must be created with
    /// zero columns.
    fn set_sensitivity_at(&mut self, index: usize, sensitivity: DMatrix<f64>) {
        let count = self.states.len();
        if index >= count {
            panic!(
                "Index {} out of bounds for trajectory with {} states",
                index, count
            );
        }
        if sensitivity.nrows() != self.dimension {
            panic!(
                "Sensitivity row count {} does not match state dimension {}",
                sensitivity.nrows(),
                self.dimension
            );
        }
        match self.sensitivity_dimension {
            Some((_, existing_cols)) if sensitivity.ncols() != existing_cols => panic!(
                "Sensitivity column count {} does not match existing {}",
                sensitivity.ncols(),
                existing_cols
            ),
            _ => {}
        }
        // Only create storage when absent: calling the enabler unconditionally would
        // re-check (and panic on) a zero column count even when storage already exists.
        if self.sensitivities.is_none() {
            self.enable_sensitivity_storage(sensitivity.ncols());
        }
        if let Some(all) = self.sensitivities.as_mut() {
            all[index] = sensitivity;
        }
    }

    /// `(rows, cols)` of the stored sensitivities, once storage has been created.
    fn sensitivity_dimensions(&self) -> Option<(usize, usize)> {
        let dims = self.sensitivity_dimension;
        dims
    }

    /// Read access to the sensitivity vector, if enabled.
    fn sensitivity_storage(&self) -> Option<&Vec<DMatrix<f64>>> {
        match &self.sensitivities {
            Some(all) => Some(all),
            None => None,
        }
    }

    /// Mutable access to the sensitivity vector, if enabled.
    fn sensitivity_storage_mut(&mut self) -> Option<&mut Vec<DMatrix<f64>>> {
        match &mut self.sensitivities {
            Some(all) => Some(all),
            None => None,
        }
    }
}
impl InterpolatableTrajectory for DTrajectory {
fn interpolate(&self, epoch: &Epoch) -> Result<DVector<f64>, BraheError> {
if let Some(start) = self.start_epoch()
&& *epoch < start
{
return Err(BraheError::OutOfBoundsError(format!(
"Cannot interpolate: epoch {} is before trajectory start {}",
epoch, start
)));
}
if let Some(end) = self.end_epoch()
&& *epoch > end
{
return Err(BraheError::OutOfBoundsError(format!(
"Cannot interpolate: epoch {} is after trajectory end {}",
epoch, end
)));
}
let idx1 = self.index_before_epoch(epoch)?;
let idx2 = self.index_after_epoch(epoch)?;
if idx1 == idx2 {
return self.state_at_idx(idx1);
}
let method = self.get_interpolation_method();
let required = method.min_points_required();
if self.len() < required {
return Err(BraheError::Error(format!(
"{:?} requires {} points, trajectory has {}",
method,
required,
self.len()
)));
}
let ref_epoch = self.start_epoch().unwrap();
match method {
InterpolationMethod::Linear => self.interpolate_linear(epoch),
InterpolationMethod::Lagrange { degree } => {
let n_points = degree + 1;
let (start_idx, end_idx) =
compute_lagrange_window(self.len(), idx1, idx2, n_points)?;
let times: Vec<f64> = (start_idx..=end_idx)
.map(|i| self.epochs[i] - ref_epoch)
.collect();
let values: Vec<DVector<f64>> = (start_idx..=end_idx)
.map(|i| self.states[i].clone())
.collect();
let t = *epoch - ref_epoch;
Ok(interpolate_lagrange_dvector(×, &values, t))
}
InterpolationMethod::HermiteCubic | InterpolationMethod::HermiteQuintic => {
Err(BraheError::Error(format!(
"{:?} interpolation requires 6D orbital states with position/velocity \
structure. Use DOrbitTrajectory for orbital states with Hermite methods, \
or use Linear/Lagrange interpolation for generic N-dimensional systems.",
self.interpolation_method
)))
}
}
}
}
fn compute_lagrange_window(
len: usize,
idx1: usize,
idx2: usize,
n_points: usize,
) -> Result<(usize, usize), BraheError> {
if len < n_points {
return Err(BraheError::Error(format!(
"Need {} points for interpolation, trajectory has {}",
n_points, len
)));
}
let center = (idx1 + idx2) / 2;
let half_window = n_points / 2;
let mut start_idx = center.saturating_sub(half_window);
let mut end_idx = start_idx + n_points - 1;
if end_idx >= len {
end_idx = len - 1;
start_idx = end_idx.saturating_sub(n_points - 1);
}
Ok((start_idx, end_idx))
}
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
use super::*;
use crate::time::{Epoch, TimeSystem};
use approx::assert_abs_diff_eq;
fn create_test_trajectory() -> DTrajectory {
let epochs = vec![
Epoch::from_jd(2451545.0, TimeSystem::UTC),
Epoch::from_jd(2451545.1, TimeSystem::UTC),
Epoch::from_jd(2451545.2, TimeSystem::UTC),
];
let states = vec![
DVector::from_vec(vec![7000e3, 0.0, 0.0, 0.0, 7.5e3, 0.0]),
DVector::from_vec(vec![7100e3, 1000e3, 500e3, 100.0, 7.6e3, 50.0]),
DVector::from_vec(vec![7200e3, 2000e3, 1000e3, 200.0, 7.7e3, 100.0]),
];
DTrajectory::from_data(epochs, states).unwrap()
}
#[test]
fn test_dtrajectory_new_with_dimension() {
let traj = DTrajectory::new(3);
assert_eq!(traj.dimension, 3);
assert_eq!(traj.len(), 0);
assert!(traj.is_empty());
let traj = DTrajectory::new(6);
assert_eq!(traj.dimension, 6);
assert_eq!(traj.len(), 0);
assert!(traj.is_empty());
let traj = DTrajectory::new(12);
assert_eq!(traj.dimension, 12);
assert_eq!(traj.len(), 0);
assert!(traj.is_empty());
}
#[test]
#[should_panic(expected = "Trajectory dimension must be greater than 0")]
fn test_dtrajectory_new_with_zero_dimension() {
let _traj = DTrajectory::new(0);
}
#[test]
fn test_dtrajectory_with_interpolation_method() {
let traj = DTrajectory::new(12).with_interpolation_method(InterpolationMethod::Linear);
assert_eq!(traj.dimension, 12);
assert_eq!(traj.interpolation_method, InterpolationMethod::Linear);
}
#[test]
fn test_dtrajectory_with_eviction_policy_max_size_builder() {
let traj = DTrajectory::new(6).with_eviction_policy_max_size(5);
assert_eq!(
traj.get_eviction_policy(),
TrajectoryEvictionPolicy::KeepCount
);
assert_eq!(traj.len(), 0);
}
#[test]
fn test_dtrajectory_with_eviction_policy_max_age_builder() {
let traj = DTrajectory::new(6).with_eviction_policy_max_age(300.0);
assert_eq!(
traj.get_eviction_policy(),
TrajectoryEvictionPolicy::KeepWithinDuration
);
assert_eq!(traj.len(), 0);
}
#[test]
fn test_dtrajectory_builder_pattern_chaining() {
let mut traj = DTrajectory::new(6)
.with_interpolation_method(InterpolationMethod::Linear)
.with_eviction_policy_max_size(10);
assert_eq!(traj.get_interpolation_method(), InterpolationMethod::Linear);
assert_eq!(
traj.get_eviction_policy(),
TrajectoryEvictionPolicy::KeepCount
);
let t0 = Epoch::from_datetime(2023, 1, 1, 12, 0, 0.0, 0.0, TimeSystem::UTC);
for i in 0..15 {
let epoch = t0 + (i as f64 * 60.0);
let state =
DVector::from_vec(vec![7000e3 + i as f64 * 1000.0, 0.0, 0.0, 0.0, 7.5e3, 0.0]);
traj.add(epoch, state);
}
assert_eq!(traj.len(), 10);
}
#[test]
fn test_dtrajectory_dimension() {
let traj = DTrajectory::new(9);
assert_eq!(traj.dimension(), 9);
let traj = DTrajectory::new(4);
assert_eq!(traj.dimension(), 4);
}
#[test]
fn test_dtrajectory_interpolatable_set_interpolation_method() {
let mut traj = DTrajectory::new(6);
assert_eq!(traj.interpolation_method, InterpolationMethod::Linear);
traj.set_interpolation_method(InterpolationMethod::Linear);
assert_eq!(traj.interpolation_method, InterpolationMethod::Linear);
}
#[test]
fn test_dtrajectory_to_matrix() {
let traj = create_test_trajectory();
let matrix = traj.to_matrix().unwrap();
assert_eq!(matrix.nrows(), 3);
assert_eq!(matrix.ncols(), 6);
assert_abs_diff_eq!(matrix[(0, 0)], 7000e3, epsilon = 1.0);
assert_abs_diff_eq!(matrix[(0, 1)], 0.0, epsilon = 1.0);
assert_abs_diff_eq!(matrix[(0, 2)], 0.0, epsilon = 1.0);
assert_abs_diff_eq!(matrix[(0, 3)], 0.0, epsilon = 1.0);
assert_abs_diff_eq!(matrix[(0, 4)], 7.5e3, epsilon = 1.0);
assert_abs_diff_eq!(matrix[(0, 5)], 0.0, epsilon = 1.0);
assert_abs_diff_eq!(matrix[(1, 0)], 7100e3, epsilon = 1.0);
assert_abs_diff_eq!(matrix[(1, 1)], 1000e3, epsilon = 1.0);
assert_abs_diff_eq!(matrix[(2, 0)], 7200e3, epsilon = 1.0);
assert_abs_diff_eq!(matrix[(2, 1)], 2000e3, epsilon = 1.0);
assert_abs_diff_eq!(matrix[(2, 2)], 1000e3, epsilon = 1.0);
assert_abs_diff_eq!(matrix[(2, 3)], 200.0, epsilon = 1.0);
assert_abs_diff_eq!(matrix[(2, 4)], 7.7e3, epsilon = 1.0);
assert_abs_diff_eq!(matrix[(2, 5)], 100.0, epsilon = 1.0);
assert_abs_diff_eq!(matrix[(0, 0)], 7000e3, epsilon = 1.0);
assert_abs_diff_eq!(matrix[(1, 0)], 7100e3, epsilon = 1.0);
assert_abs_diff_eq!(matrix[(2, 0)], 7200e3, epsilon = 1.0);
}
#[test]
fn test_dtrajectory_trajectory_get_eviction_policy() {
let mut traj = DTrajectory::new(6);
assert_eq!(traj.get_eviction_policy(), TrajectoryEvictionPolicy::None);
traj.set_eviction_policy_max_size(10).unwrap();
assert_eq!(
traj.get_eviction_policy(),
TrajectoryEvictionPolicy::KeepCount
);
traj.set_eviction_policy_max_age(100.0).unwrap();
assert_eq!(
traj.get_eviction_policy(),
TrajectoryEvictionPolicy::KeepWithinDuration
);
}
#[test]
fn test_dtrajectory_apply_eviction_policy_keep_count() {
let mut traj = DTrajectory::new(6).with_eviction_policy_max_size(3);
let t0 = Epoch::from_datetime(2023, 1, 1, 12, 0, 0.0, 0.0, TimeSystem::UTC);
for i in 0..5 {
let epoch = t0 + (i as f64 * 60.0);
let state =
DVector::from_vec(vec![7000e3 + i as f64 * 1000.0, 0.0, 0.0, 0.0, 7.5e3, 0.0]);
traj.add(epoch, state);
}
assert_eq!(traj.len(), 3);
assert_eq!(traj.epochs[0], t0 + 2.0 * 60.0); }
#[test]
fn test_dtrajectory_apply_eviction_policy_keep_within_duration() {
let mut traj = DTrajectory::new(6).with_eviction_policy_max_age(86400.0 * 7.0 - 1.0);
let t0 = Epoch::from_datetime(2023, 1, 1, 0, 0, 0.0, 0.0, TimeSystem::UTC);
for i in 0..10 {
let epoch = t0 + (i as f64 * 86400.0); let state =
DVector::from_vec(vec![7000e3 + i as f64 * 1000.0, 0.0, 0.0, 0.0, 7.5e3, 0.0]);
traj.add(epoch, state);
}
assert_eq!(traj.len(), 7);
assert_eq!(traj.epochs[0], t0 + 3.0 * 86400.0);
let mut traj = DTrajectory::new(6).with_eviction_policy_max_age(86400.0 * 7.0); for i in 0..10 {
let epoch = t0 + (i as f64 * 86400.0);
let state =
DVector::from_vec(vec![7000e3 + i as f64 * 1000.0, 0.0, 0.0, 0.0, 7.5e3, 0.0]);
traj.add(epoch, state);
}
assert_eq!(traj.len(), 8);
assert_eq!(traj.epochs[0], t0 + 2.0 * 86400.0); }
#[test]
fn test_dtrajectory_default() {
let traj = DTrajectory::default();
assert_eq!(traj.dimension, 6);
assert_eq!(traj.len(), 0);
assert!(traj.is_empty());
assert_eq!(traj.interpolation_method, InterpolationMethod::Linear);
assert_eq!(traj.eviction_policy, TrajectoryEvictionPolicy::None);
}
#[test]
fn test_dtrajectory_index() {
let traj = create_test_trajectory();
let state = &traj[0];
assert_eq!(state.len(), 6);
assert_eq!(state[0], 7000e3);
assert_eq!(state[1], 0.0);
assert_eq!(state[2], 0.0);
assert_eq!(state[3], 0.0);
assert_eq!(state[4], 7.5e3);
assert_eq!(state[5], 0.0);
let state = &traj[1];
assert_eq!(state[0], 7100e3);
assert_eq!(state[1], 1000e3);
assert_eq!(state[2], 500e3);
assert_eq!(state[3], 100.0);
assert_eq!(state[4], 7.6e3);
assert_eq!(state[5], 50.0);
let state = &traj[2];
assert_eq!(state[0], 7200e3);
assert_eq!(state[1], 2000e3);
assert_eq!(state[2], 1000e3);
assert_eq!(state[3], 200.0);
assert_eq!(state[4], 7.7e3);
assert_eq!(state[5], 100.0);
}
#[test]
#[should_panic]
fn test_dtrajectory_index_index_out_of_bounds() {
let traj = create_test_trajectory();
let _ = &traj[10]; }
#[test]
fn test_dtrajectory_iterator_iterator_len() {
let traj = create_test_trajectory();
let iter = traj.into_iter();
assert_eq!(iter.len(), 3);
}
#[test]
fn test_dtrajectory_iterator_iterator_size_hint() {
let traj = create_test_trajectory();
let iter = traj.into_iter();
let (lower, upper) = iter.size_hint();
assert_eq!(lower, 3);
assert_eq!(upper, Some(3));
}
#[test]
fn test_dtrajectory_exactsizeiterator_len() {
let traj = create_test_trajectory();
let iter = traj.into_iter();
assert_eq!(iter.len(), 3);
}
#[test]
fn test_dtrajectory_intoiterator_into_iter() {
    // Iterating `&traj` yields (epoch, state) pairs in chronological order.
    //
    // Julian dates near 2.45e6 go through `Epoch::from_jd` and back via `jd()`.
    // That round trip is not guaranteed to be bit-exact, so they are compared
    // with a small tolerance (1e-8 day is about 0.9 ms) rather than exact f64 equality.
    let traj = create_test_trajectory();
    let expected = [
        (2451545.0, 7000e3),
        (2451545.1, 7100e3),
        (2451545.2, 7200e3),
    ];
    let mut count = 0;
    for (epoch, state) in &traj {
        assert!(count < expected.len(), "Too many iterations");
        let (jd, x) = expected[count];
        assert_abs_diff_eq!(epoch.jd(), jd, epsilon = 1e-8);
        assert_abs_diff_eq!(state[0], x, epsilon = 1.0);
        count += 1;
    }
    assert_eq!(count, expected.len());
}
#[test]
fn test_dtrajectory_intoiterator_into_iter_empty() {
    // Iterating an empty trajectory by reference visits nothing.
    let empty = DTrajectory::new(6);
    let visited = (&empty).into_iter().count();
    assert_eq!(visited, 0);
}
#[test]
fn test_dtrajectory_from_data() {
    // The dimension is inferred from the supplied states (3 here, not the default 6).
    let states = vec![
        DVector::from_vec(vec![1.0, 2.0, 3.0]),
        DVector::from_vec(vec![4.0, 5.0, 6.0]),
    ];
    let epochs = vec![
        Epoch::from_jd(2451545.0, TimeSystem::UTC),
        Epoch::from_jd(2451545.1, TimeSystem::UTC),
    ];
    let built = DTrajectory::from_data(epochs, states).unwrap();
    assert_eq!(built.len(), 2);
    assert_eq!(built.dimension, 3);
}
#[test]
fn test_dtrajectory_from_data_errors() {
    // Two epochs paired with a single state: length mismatch is rejected.
    let two_epochs = vec![
        Epoch::from_jd(2451545.0, TimeSystem::UTC),
        Epoch::from_jd(2451545.1, TimeSystem::UTC),
    ];
    let one_state = vec![DVector::from_vec(vec![1.0, 2.0, 3.0])];
    assert!(DTrajectory::from_data(two_epochs, one_state).is_err());

    // No data at all is also rejected.
    let no_epochs = Vec::<Epoch>::new();
    let no_states = Vec::<DVector<f64>>::new();
    assert!(DTrajectory::from_data(no_epochs, no_states).is_err());
}
#[test]
fn test_dtrajectory_trajectory_add() {
    // Samples added in chronological order are stored in that order.
    let mut trajectory = DTrajectory::new(6);
    let samples = [
        (
            Epoch::from_datetime(2023, 1, 1, 12, 0, 0.0, 0.0, TimeSystem::UTC),
            DVector::from_vec(vec![7000e3, 0.0, 0.0, 0.0, 7.5e3, 0.0]),
        ),
        (
            Epoch::from_datetime(2023, 1, 1, 13, 0, 0.0, 0.0, TimeSystem::UTC),
            DVector::from_vec(vec![7100e3, 100e3, 50e3, 10.0, 7.6e3, 5.0]),
        ),
    ];
    for (n, (epoch, state)) in samples.iter().enumerate() {
        trajectory.add(*epoch, state.clone());
        assert_eq!(trajectory.len(), n + 1);
    }
    for (n, (_, state)) in samples.iter().enumerate() {
        assert_eq!(trajectory.states[n], *state);
    }
}
#[test]
fn test_dtrajectory_trajectory_add_out_of_order() {
    let later = Epoch::from_datetime(2023, 1, 1, 13, 0, 0.0, 0.0, TimeSystem::UTC);
    let later_state = DVector::from_vec(vec![7100e3, 100e3, 60e3, 10.0, 7.6e3, 5.0]);
    let earlier = Epoch::from_datetime(2023, 1, 1, 12, 0, 0.0, 0.0, TimeSystem::UTC);
    let earlier_state = DVector::from_vec(vec![7000e3, 0.0, 0.0, 0.0, 7.5e3, 0.0]);

    let mut trajectory = DTrajectory::new(6);
    trajectory.add(later, later_state.clone());
    assert_eq!(trajectory.len(), 1);
    assert_eq!(trajectory.epochs[0], later);
    assert_eq!(trajectory.states[0], later_state);

    // Adding an earlier epoch must insert it in front, keeping the storage sorted.
    trajectory.add(earlier, earlier_state.clone());
    assert_eq!(trajectory.len(), 2);
    assert_eq!(trajectory.epochs, vec![earlier, later]);
    assert_eq!(trajectory.states, vec![earlier_state, later_state]);
}
#[test]
#[should_panic(expected = "State vector dimension does not match trajectory dimension")]
fn test_dtrajectory_trajectory_add_dimension_mismatch() {
    // A 3-element state cannot be stored in a 6-dimensional trajectory.
    let mut trajectory = DTrajectory::new(6);
    let short_state = DVector::from_vec(vec![7000e3, 0.0, 0.0]);
    trajectory.add(
        Epoch::from_datetime(2023, 1, 1, 12, 0, 0.0, 0.0, TimeSystem::UTC),
        short_state,
    );
}
#[test]
fn test_dtrajectory_trajectory_add_append() {
    let shared_epoch = Epoch::from_datetime(2023, 1, 1, 12, 0, 0.0, 0.0, TimeSystem::UTC);
    let first = DVector::from_vec(vec![7000e3, 0.0, 0.0, 0.0, 7.5e3, 0.0]);
    let second = DVector::from_vec(vec![7100e3, 100e3, 50e3, 10.0, 7.6e3, 5.0]);

    let mut trajectory = DTrajectory::new(6);
    trajectory.add(shared_epoch, first.clone());
    assert_eq!(trajectory.len(), 1);
    assert_eq!(trajectory.states[0], first);

    // A second sample at an identical epoch is kept after the existing one.
    trajectory.add(shared_epoch, second.clone());
    assert_eq!(trajectory.len(), 2);
    assert_eq!(trajectory.states, vec![first, second]);
}
#[test]
fn test_dtrajectory_trajectory_epoch() {
    // `epoch_at_idx` returns the epoch stored at each index.
    let traj = create_test_trajectory();
    for &(idx, jd) in [(0usize, 2451545.0), (1, 2451545.1)].iter() {
        assert_eq!(
            traj.epoch_at_idx(idx).unwrap(),
            Epoch::from_jd(jd, TimeSystem::UTC)
        );
    }
}
#[test]
fn test_dtrajectory_trajectory_state() {
    // `state_at_idx` returns the state stored at each index (x component checked).
    let traj = create_test_trajectory();
    let expected_x = [7000e3, 7100e3];
    for (idx, &x) in expected_x.iter().enumerate() {
        let state = traj.state_at_idx(idx).unwrap();
        assert_abs_diff_eq!(state[0], x, epsilon = 1.0);
    }
}
#[test]
fn test_dtrajectory_trajectory_nearest_state() {
    let traj = create_test_trajectory();
    // (query JD, JD of the stored sample expected to be returned as nearest)
    let cases = [
        (2451545.05, 2451545.0),
        (2451545.09, 2451545.1),
        (2451545.11, 2451545.1),
        (2451545.2, 2451545.2),
    ];
    for &(query_jd, nearest_jd) in cases.iter() {
        let query = Epoch::from_jd(query_jd, TimeSystem::UTC);
        let (found_epoch, _) = traj.nearest_state(&query).unwrap();
        assert_eq!(found_epoch, Epoch::from_jd(nearest_jd, TimeSystem::UTC));
    }
}
#[test]
fn test_dtrajectory_trajectory_len() {
    // Three samples in the fixture; none in a freshly constructed trajectory.
    let populated = create_test_trajectory();
    let empty = DTrajectory::new(6);
    assert_eq!(populated.len(), 3);
    assert_eq!(empty.len(), 0);
}
#[test]
fn test_dtrajectory_trajectory_is_empty() {
    // `is_empty` is false for the fixture and true for a new trajectory.
    let populated = create_test_trajectory();
    let empty = DTrajectory::new(6);
    assert!(!populated.is_empty());
    assert!(empty.is_empty());
}
#[test]
fn test_dtrajectory_trajectory_start_epoch() {
    // The start epoch is the first sample's epoch, or `None` when empty.
    let populated = create_test_trajectory();
    assert_eq!(
        populated.start_epoch().unwrap(),
        Epoch::from_jd(2451545.0, TimeSystem::UTC)
    );
    assert!(DTrajectory::new(6).start_epoch().is_none());
}
#[test]
fn test_dtrajectory_trajectory_end_epoch() {
    // The end epoch is the last sample's epoch, or `None` when empty.
    let populated = create_test_trajectory();
    assert_eq!(
        populated.end_epoch().unwrap(),
        Epoch::from_jd(2451545.2, TimeSystem::UTC)
    );
    assert!(DTrajectory::new(6).end_epoch().is_none());
}
#[test]
fn test_dtrajectory_trajectory_timespan() {
    // The fixture spans 0.2 days, expressed in seconds; empty trajectories have no span.
    let populated = create_test_trajectory();
    let span_seconds = populated.timespan().unwrap();
    assert_abs_diff_eq!(span_seconds, 0.2 * 86400.0, epsilon = 1.0);
    assert!(DTrajectory::new(6).timespan().is_none());
}
#[test]
fn test_dtrajectory_trajectory_first() {
    // `first` returns the earliest (epoch, state) pair, or `None` when empty.
    let populated = create_test_trajectory();
    let (first_epoch, first_state) = populated.first().unwrap();
    assert_eq!(first_epoch, Epoch::from_jd(2451545.0, TimeSystem::UTC));
    assert_abs_diff_eq!(first_state[0], 7000e3, epsilon = 1.0);
    assert!(DTrajectory::new(6).first().is_none());
}
#[test]
fn test_dtrajectory_trajectory_last() {
    // `last` returns the latest (epoch, state) pair, or `None` when empty.
    let populated = create_test_trajectory();
    let (last_epoch, last_state) = populated.last().unwrap();
    assert_eq!(last_epoch, Epoch::from_jd(2451545.2, TimeSystem::UTC));
    assert_abs_diff_eq!(last_state[0], 7200e3, epsilon = 1.0);
    assert!(DTrajectory::new(6).last().is_none());
}
#[test]
fn test_dtrajectory_trajectory_clear() {
    // `clear` drops every sample.
    let mut traj = create_test_trajectory();
    assert_eq!(traj.len(), 3);
    traj.clear();
    assert!(traj.is_empty());
    assert_eq!(traj.len(), 0);
}
#[test]
fn test_dtrajectory_trajectory_remove_epoch() {
    // Removing by an exact stored epoch returns that sample's state.
    let mut traj = create_test_trajectory();
    let target = Epoch::from_jd(2451545.1, TimeSystem::UTC);
    let removed = traj.remove_epoch(&target).unwrap();
    assert_eq!(traj.len(), 2);
    assert_abs_diff_eq!(removed[0], 7100e3, epsilon = 1.0);
}
#[test]
fn test_dtrajectory_trajectory_remove() {
    // Removing by index returns the (epoch, state) that was stored there.
    let mut traj = create_test_trajectory();
    let (gone_epoch, gone_state) = traj.remove(1).unwrap();
    assert_eq!(traj.len(), 2);
    assert_eq!(gone_epoch, Epoch::from_jd(2451545.1, TimeSystem::UTC));
    assert_abs_diff_eq!(gone_state[0], 7100e3, epsilon = 1.0);
}
#[test]
fn test_dtrajectory_trajectory_remove_out_of_bounds() {
    // An out-of-range index is reported as an error rather than a panic.
    let mut traj = create_test_trajectory();
    assert!(traj.remove(10).is_err());
}
#[test]
fn test_dtrajectory_trajectory_get() {
    // `get` returns the (epoch, state) pair at an index without removing it.
    let traj = create_test_trajectory();
    let (found_epoch, found_state) = traj.get(1).unwrap();
    assert_abs_diff_eq!(found_state[0], 7100e3, epsilon = 1.0);
    assert_eq!(found_epoch, Epoch::from_jd(2451545.1, TimeSystem::UTC));
}
#[test]
fn test_dtrajectory_trajectory_index_before_epoch() {
    let t0 = Epoch::from_datetime(2023, 1, 1, 12, 0, 0.0, 0.0, TimeSystem::UTC);
    let t1 = t0 + 60.0;
    let t2 = t0 + 120.0;
    let traj = DTrajectory::from_data(
        vec![t0, t1, t2],
        vec![
            DVector::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]),
            DVector::from_vec(vec![11.0, 12.0, 13.0, 14.0, 15.0, 16.0]),
            DVector::from_vec(vec![21.0, 22.0, 23.0, 24.0, 25.0, 26.0]),
        ],
    )
    .unwrap();

    // Nothing precedes a query earlier than the first sample.
    assert!(traj.index_before_epoch(&(t0 - 10.0)).is_err());

    // (query epoch, index of the latest sample at or before it)
    let cases = [
        (t0 + 30.0, 0),
        (t1, 1),
        (t0 + 90.0, 1),
        (t2, 2),
        (t0 + 150.0, 2),
    ];
    for (query, expected) in cases.iter() {
        assert_eq!(traj.index_before_epoch(query).unwrap(), *expected);
    }
}
#[test]
fn test_dtrajectory_trajectory_index_after_epoch() {
    let t0 = Epoch::from_datetime(2023, 1, 1, 12, 0, 0.0, 0.0, TimeSystem::UTC);
    let t1 = t0 + 60.0;
    let t2 = t0 + 120.0;
    let traj = DTrajectory::from_data(
        vec![t0, t1, t2],
        vec![
            DVector::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]),
            DVector::from_vec(vec![11.0, 12.0, 13.0, 14.0, 15.0, 16.0]),
            DVector::from_vec(vec![21.0, 22.0, 23.0, 24.0, 25.0, 26.0]),
        ],
    )
    .unwrap();

    // (query epoch, index of the earliest sample at or after it)
    let cases = [
        (t0 - 30.0, 0),
        (t0, 0),
        (t0 + 30.0, 1),
        (t1, 1),
        (t0 + 90.0, 2),
        (t2, 2),
    ];
    for (query, expected) in cases.iter() {
        assert_eq!(traj.index_after_epoch(query).unwrap(), *expected);
    }

    // Nothing follows a query later than the last sample.
    assert!(traj.index_after_epoch(&(t0 + 150.0)).is_err());
}
#[test]
fn test_dtrajectory_trajectory_state_before_epoch() {
    let t0 = Epoch::from_datetime(2023, 1, 1, 12, 0, 0.0, 0.0, TimeSystem::UTC);
    let t1 = t0 + 60.0;
    let t2 = t0 + 120.0;
    let traj = DTrajectory::from_data(
        vec![t0, t1, t2],
        vec![
            DVector::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]),
            DVector::from_vec(vec![11.0, 12.0, 13.0, 14.0, 15.0, 16.0]),
            DVector::from_vec(vec![21.0, 22.0, 23.0, 24.0, 25.0, 26.0]),
        ],
    )
    .unwrap();

    // (query epoch, expected sample epoch, expected first state component).
    // A query exactly on a sample returns that sample.
    let cases = [(t0 + 30.0, t0, 1.0), (t0 + 90.0, t1, 11.0), (t1, t1, 11.0)];
    for (query, want_epoch, want_x) in cases.iter() {
        let (epoch, state) = traj.state_before_epoch(query).unwrap();
        assert_eq!(epoch, *want_epoch);
        assert_eq!(state[0], *want_x);
    }

    // Nothing precedes a query earlier than the first sample.
    assert!(traj.state_before_epoch(&(t0 - 10.0)).is_err());
}
#[test]
fn test_dtrajectory_trajectory_state_after_epoch() {
    let t0 = Epoch::from_datetime(2023, 1, 1, 12, 0, 0.0, 0.0, TimeSystem::UTC);
    let t1 = t0 + 60.0;
    let t2 = t0 + 120.0;
    let traj = DTrajectory::from_data(
        vec![t0, t1, t2],
        vec![
            DVector::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]),
            DVector::from_vec(vec![11.0, 12.0, 13.0, 14.0, 15.0, 16.0]),
            DVector::from_vec(vec![21.0, 22.0, 23.0, 24.0, 25.0, 26.0]),
        ],
    )
    .unwrap();

    // (query epoch, expected sample epoch, expected first state component).
    // A query exactly on a sample returns that sample.
    let cases = [(t0 + 30.0, t1, 11.0), (t0 + 90.0, t2, 21.0), (t1, t1, 11.0)];
    for (query, want_epoch, want_x) in cases.iter() {
        let (epoch, state) = traj.state_after_epoch(query).unwrap();
        assert_eq!(epoch, *want_epoch);
        assert_eq!(state[0], *want_x);
    }

    // Nothing follows a query later than the last sample.
    assert!(traj.state_after_epoch(&(t2 + 10.0)).is_err());
}
#[test]
fn test_dtrajectory_set_eviction_policy_max_size() {
    let mut traj = create_test_trajectory();
    assert_eq!(traj.len(), 3);
    // Capping the size below the current length evicts down to the cap immediately.
    let cap = 2;
    let _ = traj.set_eviction_policy_max_size(cap);
    assert_eq!(traj.eviction_policy, TrajectoryEvictionPolicy::KeepCount);
    assert_eq!(traj.len(), cap);
}
#[test]
fn test_dtrajectory_set_eviction_policy_max_age() {
    let mut traj = create_test_trajectory();
    // A 0.11-day window (in seconds) keeps two of the fixture's three samples.
    let max_age_seconds = 0.11 * 86400.0;
    let _ = traj.set_eviction_policy_max_age(max_age_seconds);
    assert_eq!(
        traj.eviction_policy,
        TrajectoryEvictionPolicy::KeepWithinDuration
    );
    assert_eq!(traj.len(), 2);
}
#[test]
fn test_dtrajectory_interpolatable_get_interpolation_method() {
    let mut traj = DTrajectory::new(6);
    // New trajectories default to linear interpolation.
    assert_eq!(traj.get_interpolation_method(), InterpolationMethod::Linear);
    // NOTE(review): the setter is only exercised with the default value, so this
    // cannot detect a setter that ignores its argument. Consider a non-default variant.
    traj.set_interpolation_method(InterpolationMethod::Linear);
    assert_eq!(traj.get_interpolation_method(), InterpolationMethod::Linear);
}
#[test]
fn test_dtrajectory_interpolatable_interpolate_linear() {
    let t0 = Epoch::from_datetime(2023, 1, 1, 12, 0, 0.0, 0.0, TimeSystem::UTC);
    let t1 = t0 + 60.0;
    let t2 = t0 + 120.0;
    // Component k grows by (k + 1) units per second of elapsed time.
    let traj = DTrajectory::from_data(
        vec![t0, t1, t2],
        vec![
            DVector::from_vec(vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0]),
            DVector::from_vec(vec![60.0, 120.0, 180.0, 240.0, 300.0, 360.0]),
            DVector::from_vec(vec![120.0, 240.0, 360.0, 480.0, 600.0, 720.0]),
        ],
    )
    .unwrap();

    // Exact sample epochs reproduce the stored states (first two components checked).
    for &(epoch, x, y) in [(t0, 0.0, 0.0), (t1, 60.0, 120.0), (t2, 120.0, 240.0)].iter() {
        let sampled = traj.interpolate_linear(&epoch).unwrap();
        assert_abs_diff_eq!(sampled[0], x, epsilon = 1e-10);
        assert_abs_diff_eq!(sampled[1], y, epsilon = 1e-10);
    }

    // 30 s into each interval, component k equals elapsed_seconds * (k + 1).
    for &(query, elapsed) in [(t0 + 30.0, 30.0), (t1 + 30.0, 90.0)].iter() {
        let blended = traj.interpolate_linear(&query).unwrap();
        for k in 0..6 {
            assert_abs_diff_eq!(blended[k], elapsed * (k as f64 + 1.0), epsilon = 1e-10);
        }
    }

    // No extrapolation outside [t0, t2].
    assert!(traj.interpolate_linear(&(t0 - 10.0)).is_err());
    assert!(traj.interpolate_linear(&(t2 + 10.0)).is_err());

    // A single-sample trajectory answers only at its own epoch.
    let single_traj = DTrajectory::from_data(
        vec![t0],
        vec![DVector::from_vec(vec![
            100.0, 200.0, 300.0, 400.0, 500.0, 600.0,
        ])],
    )
    .unwrap();
    let only = single_traj.interpolate_linear(&t0).unwrap();
    for k in 0..6 {
        assert_abs_diff_eq!(only[k], 100.0 * (k as f64 + 1.0), epsilon = 1e-10);
    }
    assert!(single_traj.interpolate_linear(&(t0 + 10.0)).is_err());

    // An empty trajectory cannot interpolate anything.
    assert!(DTrajectory::new(6).interpolate_linear(&t0).is_err());
}
#[test]
fn test_dtrajectory_interpolatable_interpolate() {
    let t0 = Epoch::from_datetime(2023, 1, 1, 12, 0, 0.0, 0.0, TimeSystem::UTC);
    let traj = DTrajectory::from_data(
        vec![t0, t0 + 60.0, t0 + 120.0],
        vec![
            DVector::from_vec(vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0]),
            DVector::from_vec(vec![60.0, 120.0, 180.0, 240.0, 300.0, 360.0]),
            DVector::from_vec(vec![120.0, 240.0, 360.0, 480.0, 600.0, 720.0]),
        ],
    )
    .unwrap();

    // With the default (linear) method, `interpolate` must match `interpolate_linear`.
    let probe = t0 + 30.0;
    let via_dispatch = traj.interpolate(&probe).unwrap();
    let via_linear = traj.interpolate_linear(&probe).unwrap();
    for k in 0..6 {
        assert_abs_diff_eq!(via_dispatch[k], via_linear[k], epsilon = 1e-10);
    }
}
#[test]
fn test_dtrajectory_interpolate_before_start() {
    let t0 = Epoch::from_datetime(2023, 1, 1, 12, 0, 0.0, 0.0, TimeSystem::UTC);
    let traj = DTrajectory::from_data(
        vec![t0, t0 + 60.0, t0 + 120.0],
        vec![
            DVector::from_vec(vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0]),
            DVector::from_vec(vec![60.0, 120.0, 180.0, 240.0, 300.0, 360.0]),
            DVector::from_vec(vec![120.0, 240.0, 360.0, 480.0, 600.0, 720.0]),
        ],
    )
    .unwrap();

    // Both interpolation entry points reject epochs before the first sample
    // with the specific out-of-bounds error variant.
    let before_start = t0 - 10.0;
    assert!(
        matches!(
            traj.interpolate_linear(&before_start),
            Err(BraheError::OutOfBoundsError(_))
        ),
        "Expected OutOfBoundsError for interpolation before start"
    );
    assert!(
        matches!(
            traj.interpolate(&before_start),
            Err(BraheError::OutOfBoundsError(_))
        ),
        "Expected OutOfBoundsError for interpolation before start"
    );
}
#[test]
fn test_dtrajectory_interpolate_after_end() {
    let t0 = Epoch::from_datetime(2023, 1, 1, 12, 0, 0.0, 0.0, TimeSystem::UTC);
    let traj = DTrajectory::from_data(
        vec![t0, t0 + 60.0, t0 + 120.0],
        vec![
            DVector::from_vec(vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0]),
            DVector::from_vec(vec![60.0, 120.0, 180.0, 240.0, 300.0, 360.0]),
            DVector::from_vec(vec![120.0, 240.0, 360.0, 480.0, 600.0, 720.0]),
        ],
    )
    .unwrap();

    // Both interpolation entry points reject epochs after the last sample
    // with the specific out-of-bounds error variant.
    let after_end = t0 + 130.0;
    assert!(
        matches!(
            traj.interpolate_linear(&after_end),
            Err(BraheError::OutOfBoundsError(_))
        ),
        "Expected OutOfBoundsError for interpolation after end"
    );
    assert!(
        matches!(
            traj.interpolate(&after_end),
            Err(BraheError::OutOfBoundsError(_))
        ),
        "Expected OutOfBoundsError for interpolation after end"
    );
}
#[test]
fn test_dtrajectory_covariance_interpolation_config() {
    let assert_method = |traj: &DTrajectory, expected: CovarianceInterpolationMethod| {
        assert_eq!(traj.get_covariance_interpolation_method(), expected);
    };

    // Default.
    assert_method(
        &DTrajectory::new(6),
        CovarianceInterpolationMethod::TwoWasserstein,
    );

    // Builder form.
    let built = DTrajectory::new(6)
        .with_covariance_interpolation_method(CovarianceInterpolationMethod::MatrixSquareRoot);
    assert_method(&built, CovarianceInterpolationMethod::MatrixSquareRoot);

    // Setter form, toggled away from and back to the default.
    let mut traj = DTrajectory::new(6);
    traj.set_covariance_interpolation_method(CovarianceInterpolationMethod::MatrixSquareRoot);
    assert_method(&traj, CovarianceInterpolationMethod::MatrixSquareRoot);
    traj.set_covariance_interpolation_method(CovarianceInterpolationMethod::TwoWasserstein);
    assert_method(&traj, CovarianceInterpolationMethod::TwoWasserstein);
}
#[test]
fn test_dtrajectory_covariance_interpolation_methods() {
    let t0 = Epoch::from_jd(2451545.0, TimeSystem::UTC);
    let t1 = t0 + 60.0;
    let mut traj = DTrajectory::new(6);
    traj.enable_covariance_storage();
    traj.add(t0, DVector::from_vec(vec![7000e3, 0.0, 0.0, 0.0, 7.5e3, 0.0]));
    traj.add(t1, DVector::from_vec(vec![7100e3, 0.0, 0.0, 0.0, 7.6e3, 0.0]));
    traj.set_covariance_at(
        0,
        DMatrix::from_diagonal(&DVector::from_vec(vec![100.0, 100.0, 100.0, 1.0, 1.0, 1.0])),
    );
    traj.set_covariance_at(
        1,
        DMatrix::from_diagonal(&DVector::from_vec(vec![200.0, 200.0, 200.0, 2.0, 2.0, 2.0])),
    );

    // Evaluate the midpoint covariance under both methods.
    let t_mid = t0 + 30.0;
    traj.set_covariance_interpolation_method(CovarianceInterpolationMethod::MatrixSquareRoot);
    let cov_sqrt = traj.covariance_at(t_mid).unwrap();
    traj.set_covariance_interpolation_method(CovarianceInterpolationMethod::TwoWasserstein);
    let cov_wasserstein = traj.covariance_at(t_mid).unwrap();

    // Each result must be symmetric with a positive diagonal, and its leading
    // variance must lie strictly between the two endpoint values.
    for cov in [&cov_sqrt, &cov_wasserstein].iter() {
        for i in 0..6 {
            assert!(cov[(i, i)] > 0.0);
            for j in 0..6 {
                assert_abs_diff_eq!(cov[(i, j)], cov[(j, i)], epsilon = 1e-10);
            }
        }
        assert!(cov[(0, 0)] > 100.0 && cov[(0, 0)] < 200.0);
    }

    // For these diagonal (commuting) inputs, the two methods agree.
    assert_abs_diff_eq!(cov_sqrt[(0, 0)], cov_wasserstein[(0, 0)], epsilon = 1e-6);
}
#[test]
fn test_dtrajectory_covariance_at_exact_epochs() {
    let t0 = Epoch::from_jd(2451545.0, TimeSystem::UTC);
    let t1 = t0 + 60.0;
    let mut traj = DTrajectory::new(6);
    traj.enable_covariance_storage();
    traj.add(t0, DVector::from_vec(vec![7000e3, 0.0, 0.0, 0.0, 7.5e3, 0.0]));
    traj.add(t1, DVector::from_vec(vec![7100e3, 0.0, 0.0, 0.0, 7.6e3, 0.0]));

    let variances = [100.0, 200.0];
    for (idx, &variance) in variances.iter().enumerate() {
        traj.set_covariance_at(idx, DMatrix::identity(6, 6) * variance);
    }

    // Querying exactly at a stored epoch returns the stored covariance.
    for (epoch, &variance) in [t0, t1].iter().zip(variances.iter()) {
        let cov = traj.covariance_at(*epoch).unwrap();
        assert_abs_diff_eq!(cov[(0, 0)], variance, epsilon = 1e-10);
    }
}
#[test]
fn test_dtrajectory_enable_stm_storage() {
    let mut traj = create_test_trajectory();
    assert!(traj.stms.is_none());

    // Enabling storage fills in one 6x6 identity STM per existing sample.
    traj.enable_stm_storage();
    assert!(traj.stms.is_some());
    let stms = traj.stms.as_ref().unwrap();
    assert_eq!(stms.len(), 3);

    let identity = DMatrix::<f64>::identity(6, 6);
    for stm in stms {
        assert_eq!(stm.nrows(), 6);
        assert_eq!(stm.ncols(), 6);
        for (got, want) in stm.iter().zip(identity.iter()) {
            assert_abs_diff_eq!(*got, *want, epsilon = 1e-10);
        }
    }
}
#[test]
fn test_dtrajectory_enable_stm_storage_idempotent() {
    let mut traj = create_test_trajectory();
    traj.enable_stm_storage();
    traj.set_stm_at(0, DMatrix::from_element(6, 6, 2.0));
    // Re-enabling must not reset STMs that are already stored.
    traj.enable_stm_storage();
    assert_abs_diff_eq!(traj.stm_at_idx(0).unwrap()[(0, 0)], 2.0, epsilon = 1e-10);
}
#[test]
fn test_dtrajectory_set_stm_at() {
    let mut traj = create_test_trajectory();
    traj.enable_stm_storage();
    traj.set_stm_at(1, DMatrix::from_element(6, 6, 5.0));
    // Spot-check two entries of the constant matrix that was written.
    let stored = traj.stm_at_idx(1).unwrap();
    for &(r, c) in [(0usize, 0usize), (3, 3)].iter() {
        assert_abs_diff_eq!(stored[(r, c)], 5.0, epsilon = 1e-10);
    }
}
#[test]
fn test_dtrajectory_set_stm_at_auto_enables() {
    let mut traj = create_test_trajectory();
    assert!(traj.stms.is_none());

    // Writing an STM without enabling storage first turns storage on implicitly.
    traj.set_stm_at(0, DMatrix::from_element(6, 6, 3.0));
    assert!(traj.stms.is_some());
    assert_abs_diff_eq!(traj.stm_at_idx(0).unwrap()[(0, 0)], 3.0, epsilon = 1e-10);

    // Untouched samples look like identity (checked on one diagonal and one off-diagonal entry).
    let neighbour = traj.stm_at_idx(1).unwrap();
    assert_abs_diff_eq!(neighbour[(0, 0)], 1.0, epsilon = 1e-10);
    assert_abs_diff_eq!(neighbour[(0, 1)], 0.0, epsilon = 1e-10);
}
#[test]
#[should_panic(expected = "STM dimensions")]
fn test_dtrajectory_set_stm_at_dimension_mismatch() {
    let mut traj = create_test_trajectory();
    traj.enable_stm_storage();
    // A 3x3 STM does not fit a 6-dimensional state.
    traj.set_stm_at(0, DMatrix::identity(3, 3));
}
#[test]
#[should_panic(expected = "out of bounds")]
fn test_dtrajectory_set_stm_at_out_of_bounds() {
    let mut traj = create_test_trajectory();
    traj.enable_stm_storage();
    // Only indices 0..3 exist in the fixture.
    traj.set_stm_at(10, DMatrix::identity(6, 6));
}
#[test]
fn test_dtrajectory_stm_at_idx() {
    let mut traj = create_test_trajectory();
    traj.enable_stm_storage();
    // Entry (i, j) holds 6*i + j, so any transposition or misplacement is visible.
    traj.set_stm_at(2, DMatrix::from_fn(6, 6, |row, col| (row * 6 + col) as f64));
    let stored = traj.stm_at_idx(2).unwrap();
    for &(r, c) in [(0usize, 0usize), (0, 1), (1, 0), (5, 5)].iter() {
        assert_abs_diff_eq!(stored[(r, c)], (r * 6 + c) as f64, epsilon = 1e-10);
    }
}
#[test]
fn test_dtrajectory_stm_at_idx_no_storage() {
    // Without STM storage enabled, index lookups yield nothing.
    let traj = create_test_trajectory();
    for idx in 0..2 {
        assert!(traj.stm_at_idx(idx).is_none());
    }
}
#[test]
fn test_dtrajectory_stm_at_interpolation() {
    let mut traj = create_test_trajectory();
    traj.enable_stm_storage();
    for (idx, fill) in [10.0, 20.0].iter().enumerate() {
        traj.set_stm_at(idx, DMatrix::from_element(6, 6, *fill));
    }

    // Halfway between samples filled with 10 and 20, entries should read 15.
    let (start, end) = (traj.epochs[0], traj.epochs[1]);
    let halfway = start + (end - start) / 2.0;
    let blended = traj.stm_at(halfway).unwrap();
    for &k in [0usize, 3].iter() {
        assert_abs_diff_eq!(blended[(k, k)], 15.0, epsilon = 1e-10);
    }
}
#[test]
fn test_dtrajectory_stm_dimensions() {
    // The STM is square with side equal to the state dimension.
    for &n in [6usize, 9].iter() {
        assert_eq!(DTrajectory::new(n).stm_dimensions(), (n, n));
    }
}
#[test]
fn test_dtrajectory_enable_sensitivity_storage() {
    let mut traj = create_test_trajectory();
    assert!(traj.sensitivities.is_none());
    assert!(traj.sensitivity_dimension.is_none());

    // Enabling with 3 parameters creates one zero-filled 6x3 matrix per sample.
    traj.enable_sensitivity_storage(3);
    assert!(traj.sensitivities.is_some());
    assert_eq!(traj.sensitivity_dimension, Some((6, 3)));

    let sensitivities = traj.sensitivities.as_ref().unwrap();
    assert_eq!(sensitivities.len(), 3);
    for sens in sensitivities {
        assert_eq!(sens.shape(), (6, 3));
        for value in sens.iter() {
            assert_abs_diff_eq!(*value, 0.0, epsilon = 1e-10);
        }
    }
}
#[test]
#[should_panic(expected = "Parameter dimension must be > 0")]
fn test_dtrajectory_enable_sensitivity_storage_zero_param() {
    // A sensitivity matrix needs at least one parameter column.
    let mut traj = create_test_trajectory();
    let zero_params = 0;
    traj.enable_sensitivity_storage(zero_params);
}
#[test]
fn test_dtrajectory_set_sensitivity_at() {
    let mut traj = create_test_trajectory();
    traj.enable_sensitivity_storage(2);
    traj.set_sensitivity_at(1, DMatrix::from_element(6, 2, 7.0));
    // Check opposite corners of the 6x2 matrix that was written.
    let stored = traj.sensitivity_at_idx(1).unwrap();
    for &(r, c) in [(0usize, 0usize), (5, 1)].iter() {
        assert_abs_diff_eq!(stored[(r, c)], 7.0, epsilon = 1e-10);
    }
}
#[test]
fn test_dtrajectory_set_sensitivity_at_auto_enables() {
    let mut traj = create_test_trajectory();
    assert!(traj.sensitivities.is_none());

    // Writing a 6x4 sensitivity implicitly enables storage with 4 parameters.
    traj.set_sensitivity_at(0, DMatrix::from_element(6, 4, 9.0));
    assert!(traj.sensitivities.is_some());
    assert_eq!(traj.sensitivity_dimensions(), Some((6, 4)));
    assert_abs_diff_eq!(traj.sensitivity_at_idx(0).unwrap()[(0, 0)], 9.0, epsilon = 1e-10);

    // Other samples are zero-filled.
    assert_abs_diff_eq!(traj.sensitivity_at_idx(1).unwrap()[(0, 0)], 0.0, epsilon = 1e-10);
}
#[test]
#[should_panic(expected = "row count")]
fn test_dtrajectory_set_sensitivity_at_row_mismatch() {
    let mut traj = create_test_trajectory();
    traj.enable_sensitivity_storage(2);
    // 3 rows instead of the state dimension (6).
    traj.set_sensitivity_at(0, DMatrix::from_element(3, 2, 1.0));
}
#[test]
#[should_panic(expected = "column count")]
fn test_dtrajectory_set_sensitivity_at_col_mismatch() {
    let mut traj = create_test_trajectory();
    traj.enable_sensitivity_storage(2);
    // 5 columns instead of the configured parameter count (2).
    traj.set_sensitivity_at(0, DMatrix::from_element(6, 5, 1.0));
}
#[test]
fn test_dtrajectory_sensitivity_at_idx() {
    let mut traj = create_test_trajectory();
    traj.enable_sensitivity_storage(2);
    // Entry (i, j) holds 2*i + j, so misplaced entries are detectable.
    traj.set_sensitivity_at(2, DMatrix::from_fn(6, 2, |row, col| (row * 2 + col) as f64));
    let stored = traj.sensitivity_at_idx(2).unwrap();
    for &(r, c) in [(0usize, 0usize), (0, 1), (1, 0), (5, 1)].iter() {
        assert_abs_diff_eq!(stored[(r, c)], (r * 2 + c) as f64, epsilon = 1e-10);
    }
}
#[test]
fn test_dtrajectory_sensitivity_at_idx_no_storage() {
    // Without sensitivity storage enabled, index lookups yield nothing.
    let traj = create_test_trajectory();
    for idx in 0..2 {
        assert!(traj.sensitivity_at_idx(idx).is_none());
    }
}
#[test]
fn test_dtrajectory_sensitivity_at_interpolation() {
    let mut traj = create_test_trajectory();
    traj.enable_sensitivity_storage(2);
    for (idx, fill) in [100.0, 200.0].iter().enumerate() {
        traj.set_sensitivity_at(idx, DMatrix::from_element(6, 2, *fill));
    }

    // Halfway between samples filled with 100 and 200, entries should read 150.
    let (start, end) = (traj.epochs[0], traj.epochs[1]);
    let halfway = start + (end - start) / 2.0;
    let blended = traj.sensitivity_at(halfway).unwrap();
    for &(r, c) in [(0usize, 0usize), (5, 1)].iter() {
        assert_abs_diff_eq!(blended[(r, c)], 150.0, epsilon = 1e-10);
    }
}
#[test]
fn test_dtrajectory_sensitivity_dimensions() {
    // No sensitivity shape until storage is enabled.
    assert_eq!(DTrajectory::new(6).sensitivity_dimensions(), None);

    // Afterwards the shape is (state dimension, parameter count).
    for &(state_dim, param_dim) in [(6usize, 4usize), (9, 2)].iter() {
        let mut traj = DTrajectory::new(state_dim);
        traj.enable_sensitivity_storage(param_dim);
        assert_eq!(traj.sensitivity_dimensions(), Some((state_dim, param_dim)));
    }
}
#[test]
fn test_dtrajectory_add_full_state_only() {
    let epoch = Epoch::from_jd(2451545.0, TimeSystem::UTC);
    let mut traj = DTrajectory::new(6);
    traj.add_full(
        epoch,
        DVector::from_vec(vec![7000e3, 0.0, 0.0, 0.0, 7.5e3, 0.0]),
        None,
        None,
        None,
    );
    assert_eq!(traj.len(), 1);

    // With no optional matrices supplied, none of the optional stores are created.
    assert!(traj.covariances.is_none());
    assert!(traj.stms.is_none());
    assert!(traj.sensitivities.is_none());

    let (stored_epoch, stored_state) = traj.get(0).unwrap();
    assert_eq!(stored_epoch, epoch);
    assert_abs_diff_eq!(stored_state[0], 7000e3, epsilon = 1.0);
}
#[test]
fn test_dtrajectory_add_full_with_covariance() {
    let mut traj = DTrajectory::new(6);
    traj.add_full(
        Epoch::from_jd(2451545.0, TimeSystem::UTC),
        DVector::from_vec(vec![7000e3, 0.0, 0.0, 0.0, 7.5e3, 0.0]),
        Some(DMatrix::identity(6, 6) * 100.0),
        None,
        None,
    );
    assert_eq!(traj.len(), 1);

    // Only the covariance store is created.
    assert!(traj.covariances.is_some());
    assert!(traj.stms.is_none());
    assert!(traj.sensitivities.is_none());

    let stored_cov = &traj.covariances.as_ref().unwrap()[0];
    assert_abs_diff_eq!(stored_cov[(0, 0)], 100.0, epsilon = 1e-10);
}
#[test]
fn test_dtrajectory_add_full_with_stm() {
    let mut traj = DTrajectory::new(6);
    traj.add_full(
        Epoch::from_jd(2451545.0, TimeSystem::UTC),
        DVector::from_vec(vec![7000e3, 0.0, 0.0, 0.0, 7.5e3, 0.0]),
        None,
        Some(DMatrix::from_element(6, 6, 2.0)),
        None,
    );
    assert_eq!(traj.len(), 1);

    // Only the STM store is created.
    assert!(traj.covariances.is_none());
    assert!(traj.stms.is_some());
    assert!(traj.sensitivities.is_none());

    let stored_stm = &traj.stms.as_ref().unwrap()[0];
    assert_abs_diff_eq!(stored_stm[(0, 0)], 2.0, epsilon = 1e-10);
}
#[test]
fn test_dtrajectory_add_full_with_sensitivity() {
    let mut traj = DTrajectory::new(6);
    traj.add_full(
        Epoch::from_jd(2451545.0, TimeSystem::UTC),
        DVector::from_vec(vec![7000e3, 0.0, 0.0, 0.0, 7.5e3, 0.0]),
        None,
        None,
        Some(DMatrix::from_element(6, 3, 5.0)),
    );
    assert_eq!(traj.len(), 1);

    // Only the sensitivity store is created, shaped from the supplied matrix.
    assert!(traj.covariances.is_none());
    assert!(traj.stms.is_none());
    assert!(traj.sensitivities.is_some());
    assert_eq!(traj.sensitivity_dimensions(), Some((6, 3)));

    let stored_sens = &traj.sensitivities.as_ref().unwrap()[0];
    assert_abs_diff_eq!(stored_sens[(0, 0)], 5.0, epsilon = 1e-10);
}
#[test]
fn test_dtrajectory_add_full_all_matrices() {
    let mut traj = DTrajectory::new(6);
    traj.add_full(
        Epoch::from_jd(2451545.0, TimeSystem::UTC),
        DVector::from_vec(vec![7000e3, 0.0, 0.0, 0.0, 7.5e3, 0.0]),
        Some(DMatrix::identity(6, 6) * 100.0),
        Some(DMatrix::from_element(6, 6, 2.0)),
        Some(DMatrix::from_element(6, 3, 5.0)),
    );
    assert_eq!(traj.len(), 1);
    // Supplying every optional matrix creates every optional store.
    assert!(traj.covariances.is_some());
    assert!(traj.stms.is_some());
    assert!(traj.sensitivities.is_some());
}
#[test]
fn test_dtrajectory_add_full_maintains_order() {
    // States are inserted out of chronological order (t1, t0, t2) and must be
    // stored sorted by epoch, with each state staying paired with its epoch.
    let make_state = |x: f64| DVector::from_vec(vec![x, 0.0, 0.0, 0.0, 7.5e3, 0.0]);
    let t0 = Epoch::from_jd(2451545.0, TimeSystem::UTC);
    let t1 = t0 + 60.0;
    let t2 = t0 + 120.0;

    let mut traj = DTrajectory::new(6);
    for (epoch, x) in [(t1, 7100e3), (t0, 7000e3), (t2, 7200e3)] {
        traj.add_full(epoch, make_state(x), None, None, None);
    }

    assert_eq!(traj.len(), 3);
    let expected = [(t0, 7000e3), (t1, 7100e3), (t2, 7200e3)];
    for (k, &(epoch, x)) in expected.iter().enumerate() {
        assert_eq!(traj.epochs[k], epoch);
        assert_abs_diff_eq!(traj.states[k][0], x, epsilon = 1.0);
    }
}
#[test]
#[should_panic(expected = "State vector dimension")]
fn test_dtrajectory_add_full_state_dimension_mismatch() {
    // A 3-element state is offered to a trajectory built with dimension 6.
    let short_state = DVector::from_vec(vec![7000e3, 0.0, 0.0]);
    DTrajectory::new(6).add_full(
        Epoch::from_jd(2451545.0, TimeSystem::UTC),
        short_state,
        None,
        None,
        None,
    );
}
#[test]
#[should_panic(expected = "STM dimensions")]
fn test_dtrajectory_add_full_stm_dimension_mismatch() {
    // A 3x3 STM is paired with a 6-element state in a dimension-6 trajectory.
    let bad_stm = DMatrix::identity(3, 3);
    let mut traj = DTrajectory::new(6);
    traj.add_full(
        Epoch::from_jd(2451545.0, TimeSystem::UTC),
        DVector::from_vec(vec![7000e3, 0.0, 0.0, 0.0, 7.5e3, 0.0]),
        None,
        Some(bad_stm),
        None,
    );
}
#[test]
#[should_panic(expected = "Sensitivity row count")]
fn test_dtrajectory_add_full_sensitivity_row_mismatch() {
    // The sensitivity has 3 rows while the trajectory dimension is 6.
    let bad_sens = DMatrix::from_element(3, 2, 1.0);
    let mut traj = DTrajectory::new(6);
    traj.add_full(
        Epoch::from_jd(2451545.0, TimeSystem::UTC),
        DVector::from_vec(vec![7000e3, 0.0, 0.0, 0.0, 7.5e3, 0.0]),
        None,
        None,
        Some(bad_sens),
    );
}
#[test]
#[should_panic(expected = "Sensitivity column count")]
fn test_dtrajectory_add_full_sensitivity_col_mismatch() {
    // The first add stores a 6x2 sensitivity. The second add, 60 s later,
    // offers a 6x5 sensitivity and is expected to panic.
    let t0 = Epoch::from_jd(2451545.0, TimeSystem::UTC);
    let state = DVector::from_vec(vec![7000e3, 0.0, 0.0, 0.0, 7.5e3, 0.0]);
    let mut traj = DTrajectory::new(6);
    for (epoch, n_cols) in [(t0, 2), (t0 + 60.0, 5)] {
        let sens = DMatrix::from_element(6, n_cols, 1.0);
        traj.add_full(epoch, state.clone(), None, None, Some(sens));
    }
}
#[test]
fn test_dtrajectory_epoch_initial() {
    // The initial epoch of a populated trajectory is its first stored epoch.
    let traj = create_test_trajectory();
    let first_epoch = traj.epochs[0];
    assert_eq!(traj.epoch_initial(), Some(first_epoch));
}
#[test]
fn test_dtrajectory_epoch_initial_empty() {
    // A freshly constructed trajectory holds no epochs at all.
    assert!(DTrajectory::new(6).epoch_initial().is_none());
}
#[test]
fn test_dtrajectory_find_surrounding_indices() {
    // An epoch halfway between the first two samples is bracketed by indices 0 and 1.
    let traj = create_test_trajectory();
    let (first, second) = (traj.epochs[0], traj.epochs[1]);
    let half_step = (second - first) / 2.0;
    let query = first + half_step;
    assert_eq!(traj.find_surrounding_indices(query), Some((0, 1)));
}
#[test]
fn test_dtrajectory_find_surrounding_indices_empty() {
    // With no samples there is nothing to bracket any query epoch.
    let empty = DTrajectory::new(6);
    let query = Epoch::from_jd(2451545.0, TimeSystem::UTC);
    let result = empty.find_surrounding_indices(query);
    assert!(result.is_none());
}
#[test]
fn test_dtrajectory_find_surrounding_indices_before_start() {
    // A query 100 s before the first sample lies outside the trajectory span.
    let traj = create_test_trajectory();
    let first = traj.epochs[0];
    let query = first - 100.0;
    let result = traj.find_surrounding_indices(query);
    assert!(result.is_none());
}
#[test]
fn test_dtrajectory_find_surrounding_indices_after_end() {
    // A query 100 s past the last sample lies outside the trajectory span.
    let traj = create_test_trajectory();
    // Use the real last epoch rather than a hard-coded index. With `epochs[2]`,
    // the query would land inside the span if the fixture ever held more than
    // three points, and the test would no longer exercise the after-end case.
    let last = *traj
        .epochs
        .last()
        .expect("test trajectory should contain at least one epoch");
    let after = last + 100.0;
    assert!(traj.find_surrounding_indices(after).is_none());
}
#[test]
fn test_dtrajectory_eviction_keep_count_with_covariances() {
    // Five covariance-bearing states go into a trajectory capped at three
    // entries. The surviving covariances must start at i = 2 (value 20).
    let t0 = Epoch::from_datetime(2023, 1, 1, 12, 0, 0.0, 0.0, TimeSystem::UTC);
    let mut traj = DTrajectory::new(6).with_eviction_policy_max_size(3);
    traj.enable_covariance_storage();

    for i in 0..5 {
        let k = i as f64;
        let state = DVector::from_vec(vec![7000e3 + k * 1000.0, 0.0, 0.0, 0.0, 7.5e3, 0.0]);
        traj.add_with_covariance(t0 + k * 60.0, state, DMatrix::identity(6, 6) * (k * 10.0));
    }

    assert_eq!(traj.len(), 3);
    let covs = traj.covariances.as_ref().unwrap();
    assert_eq!(covs.len(), 3);
    assert_abs_diff_eq!(covs[0][(0, 0)], 20.0, epsilon = 1e-10);
}
#[test]
fn test_dtrajectory_eviction_keep_count_with_stms() {
    // Each state's STM is attached after the state is added. With a cap of
    // three, the oldest surviving STM belongs to i = 2.
    let t0 = Epoch::from_datetime(2023, 1, 1, 12, 0, 0.0, 0.0, TimeSystem::UTC);
    let mut traj = DTrajectory::new(6).with_eviction_policy_max_size(3);
    traj.enable_stm_storage();

    for i in 0..5 {
        let k = i as f64;
        let state = DVector::from_vec(vec![7000e3 + k * 1000.0, 0.0, 0.0, 0.0, 7.5e3, 0.0]);
        traj.add(t0 + k * 60.0, state);
        // Target the entry just added, measured after any eviction `add` performed.
        let newest = traj.len() - 1;
        traj.set_stm_at(newest, DMatrix::from_element(6, 6, k));
    }

    assert_eq!(traj.len(), 3);
    let stms = traj.stms.as_ref().unwrap();
    assert_eq!(stms.len(), 3);
    assert_abs_diff_eq!(stms[0][(0, 0)], 2.0, epsilon = 1e-10);
}
#[test]
fn test_dtrajectory_eviction_keep_count_with_sensitivities() {
    // Each state's 6x2 sensitivity is attached after the state is added. With
    // a cap of three, the oldest surviving sensitivity belongs to i = 2 (value 20).
    let t0 = Epoch::from_datetime(2023, 1, 1, 12, 0, 0.0, 0.0, TimeSystem::UTC);
    let mut traj = DTrajectory::new(6).with_eviction_policy_max_size(3);
    traj.enable_sensitivity_storage(2);

    for i in 0..5 {
        let k = i as f64;
        let state = DVector::from_vec(vec![7000e3 + k * 1000.0, 0.0, 0.0, 0.0, 7.5e3, 0.0]);
        traj.add(t0 + k * 60.0, state);
        // Target the entry just added, measured after any eviction `add` performed.
        let newest = traj.len() - 1;
        traj.set_sensitivity_at(newest, DMatrix::from_element(6, 2, k * 10.0));
    }

    assert_eq!(traj.len(), 3);
    let sens = traj.sensitivities.as_ref().unwrap();
    assert_eq!(sens.len(), 3);
    assert_abs_diff_eq!(sens[0][(0, 0)], 20.0, epsilon = 1e-10);
}
#[test]
fn test_dtrajectory_eviction_keep_count_all_data() {
    // Four fully populated states go into a trajectory capped at two. All three
    // matrix stores must shrink in lockstep, so the oldest survivor (i = 2) has
    // covariance 2, STM 4 and sensitivity 6 at (0, 0).
    let t0 = Epoch::from_datetime(2023, 1, 1, 12, 0, 0.0, 0.0, TimeSystem::UTC);
    let mut traj = DTrajectory::new(6).with_eviction_policy_max_size(2);

    for i in 0..4 {
        let k = i as f64;
        traj.add_full(
            t0 + k * 60.0,
            DVector::from_vec(vec![7000e3 + k * 1000.0, 0.0, 0.0, 0.0, 7.5e3, 0.0]),
            Some(DMatrix::identity(6, 6) * k),
            Some(DMatrix::from_element(6, 6, k * 2.0)),
            Some(DMatrix::from_element(6, 3, k * 3.0)),
        );
    }

    assert_eq!(traj.len(), 2);
    let covs = traj.covariances.as_ref().unwrap();
    let stms = traj.stms.as_ref().unwrap();
    let sens = traj.sensitivities.as_ref().unwrap();
    assert_eq!(covs.len(), 2);
    assert_eq!(stms.len(), 2);
    assert_eq!(sens.len(), 2);

    assert_abs_diff_eq!(covs[0][(0, 0)], 2.0, epsilon = 1e-10);
    assert_abs_diff_eq!(stms[0][(0, 0)], 4.0, epsilon = 1e-10);
    assert_abs_diff_eq!(sens[0][(0, 0)], 6.0, epsilon = 1e-10);
}
#[test]
fn test_dtrajectory_eviction_keep_within_duration_with_covariances() {
    // Points are 60 s apart (0..=240 s). A 150 s max age is expected to keep
    // only the three newest points (120, 180, 240 s).
    let mut traj = DTrajectory::new(6).with_eviction_policy_max_age(150.0);
    traj.enable_covariance_storage();
    let t0 = Epoch::from_datetime(2023, 1, 1, 12, 0, 0.0, 0.0, TimeSystem::UTC);
    for i in 0..5 {
        let epoch = t0 + (i as f64 * 60.0);
        let state =
            DVector::from_vec(vec![7000e3 + i as f64 * 1000.0, 0.0, 0.0, 0.0, 7.5e3, 0.0]);
        let cov = DMatrix::identity(6, 6) * (i as f64 * 10.0);
        traj.add_with_covariance(epoch, state, cov);
    }
    assert_eq!(traj.len(), 3);
    assert_eq!(traj.covariances.as_ref().unwrap().len(), 3);
    // Lengths alone cannot show which end was evicted. Check that the oldest
    // survivor is i = 2 and that its covariance (2 * 10) stayed aligned with it.
    assert_eq!(traj.epochs[0], t0 + 120.0);
    assert_abs_diff_eq!(
        traj.covariances.as_ref().unwrap()[0][(0, 0)],
        20.0,
        epsilon = 1e-10
    );
}
#[test]
fn test_dtrajectory_eviction_keep_within_duration_with_stms() {
    // Points are 60 s apart (0..=240 s). A 150 s max age is expected to keep
    // only the three newest points (120, 180, 240 s) together with their STMs.
    let mut traj = DTrajectory::new(6).with_eviction_policy_max_age(150.0);
    let t0 = Epoch::from_datetime(2023, 1, 1, 12, 0, 0.0, 0.0, TimeSystem::UTC);
    for i in 0..5 {
        let epoch = t0 + (i as f64 * 60.0);
        let state =
            DVector::from_vec(vec![7000e3 + i as f64 * 1000.0, 0.0, 0.0, 0.0, 7.5e3, 0.0]);
        let stm = DMatrix::from_element(6, 6, i as f64);
        traj.add_full(epoch, state, None, Some(stm), None);
    }
    assert_eq!(traj.len(), 3);
    assert_eq!(traj.stms.as_ref().unwrap().len(), 3);
    // Lengths alone cannot show which end was evicted. Check that the oldest
    // survivor is i = 2 and that its STM (filled with 2.0) stayed aligned with it.
    assert_eq!(traj.epochs[0], t0 + 120.0);
    assert_abs_diff_eq!(traj.stms.as_ref().unwrap()[0][(0, 0)], 2.0, epsilon = 1e-10);
}
#[test]
fn test_dtrajectory_eviction_keep_within_duration_with_sensitivities() {
    // Points are 60 s apart (0..=240 s). A 150 s max age is expected to keep
    // only the three newest points (120, 180, 240 s) together with their
    // sensitivities.
    let mut traj = DTrajectory::new(6).with_eviction_policy_max_age(150.0);
    let t0 = Epoch::from_datetime(2023, 1, 1, 12, 0, 0.0, 0.0, TimeSystem::UTC);
    for i in 0..5 {
        let epoch = t0 + (i as f64 * 60.0);
        let state =
            DVector::from_vec(vec![7000e3 + i as f64 * 1000.0, 0.0, 0.0, 0.0, 7.5e3, 0.0]);
        let sens = DMatrix::from_element(6, 2, i as f64);
        traj.add_full(epoch, state, None, None, Some(sens));
    }
    assert_eq!(traj.len(), 3);
    assert_eq!(traj.sensitivities.as_ref().unwrap().len(), 3);
    // Lengths alone cannot show which end was evicted. Check that the oldest
    // survivor is i = 2 and that its sensitivity (filled with 2.0) stayed aligned.
    assert_eq!(traj.epochs[0], t0 + 120.0);
    assert_abs_diff_eq!(
        traj.sensitivities.as_ref().unwrap()[0][(0, 0)],
        2.0,
        epsilon = 1e-10
    );
}
#[test]
fn test_dtrajectory_with_interpolation_method_builder_pattern() {
    // `DTrajectory::new` already defaults to `Linear`, so asserting `Linear`
    // after a single builder call cannot detect a no-op builder. Switch to a
    // non-default method first, confirm it took effect, then switch back to
    // `Linear` and confirm the builder overwrote it.
    let lagrange =
        DTrajectory::new(6).with_interpolation_method(InterpolationMethod::Lagrange { degree: 3 });
    assert_eq!(
        lagrange.get_interpolation_method(),
        InterpolationMethod::Lagrange { degree: 3 }
    );

    let traj = lagrange.with_interpolation_method(InterpolationMethod::Linear);
    assert_eq!(traj.get_interpolation_method(), InterpolationMethod::Linear);
}
#[test]
fn test_dtrajectory_with_interpolation_method_lagrange() {
    // The builder should store the Lagrange variant together with its degree.
    let method = DTrajectory::new(6)
        .with_interpolation_method(InterpolationMethod::Lagrange { degree: 5 })
        .get_interpolation_method();
    if let InterpolationMethod::Lagrange { degree } = method {
        assert_eq!(degree, 5);
    } else {
        panic!("Expected Lagrange interpolation method");
    }
}
}